-
Notifications
You must be signed in to change notification settings - Fork 28.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-28169][SQL] Convert scan predicate condition to CNF #28805
Changes from 24 commits
3356bac
346a1b4
250c7b3
15d62be
39e85ad
d8f7c9e
8856453
697a3a9
7e8319e
3734866
b253af3
478a7a8
603660b
69f1763
2f576fa
e71c45c
94609c8
326fb49
9322ae6
4a2adcd
0e2579d
270324e
219f200
f21cf43
35b5813
3df019a
1b8466e
e2777c9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -211,7 +211,9 @@ trait PredicateHelper extends Logging { | |
* @return the CNF result as sequence of disjunctive expressions. If the number of expressions | ||
* exceeds threshold on converting `Or`, `Seq.empty` is returned. | ||
*/ | ||
protected def conjunctiveNormalForm(condition: Expression): Seq[Expression] = { | ||
protected def conjunctiveNormalForm( | ||
condition: Expression, | ||
groupExpsFunc: Seq[Expression] => Seq[Expression] = _.toSeq): Seq[Expression] = { | ||
val postOrderNodes = postOrderTraversal(condition) | ||
val resultStack = new mutable.Stack[Seq[Expression]] | ||
val maxCnfNodeCount = SQLConf.get.maxCnfNodeCount | ||
|
@@ -226,8 +228,8 @@ trait PredicateHelper extends Logging { | |
// For each side, there is no need to expand predicates of the same references. | ||
// So here we can aggregate predicates of the same qualifier as one single predicate, | ||
// for reducing the size of pushed down predicates and corresponding codegen. | ||
val right = groupExpressionsByQualifier(resultStack.pop()) | ||
val left = groupExpressionsByQualifier(resultStack.pop()) | ||
val right = groupExpsFunc(resultStack.pop()) | ||
val left = groupExpsFunc(resultStack.pop()) | ||
// Stop the loop whenever the result exceeds the `maxCnfNodeCount` | ||
if (left.size * right.size > maxCnfNodeCount) { | ||
logInfo(s"As the result size exceeds the threshold $maxCnfNodeCount. " + | ||
|
@@ -249,8 +251,36 @@ trait PredicateHelper extends Logging { | |
resultStack.top | ||
} | ||
|
||
private def groupExpressionsByQualifier(expressions: Seq[Expression]): Seq[Expression] = { | ||
expressions.groupBy(_.references.map(_.qualifier)).map(_._2.reduceLeft(And)).toSeq | ||
/** | ||
* Convert an expression to conjunctive normal form when pushing predicates through Join. | ||
* When expanding predicates, we can group by the qualifier to avoid generating unnecessary | ||
* expressions and to control the length of the final result, since there are multiple tables. | ||
* @param condition the condition to be converted |
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: convert -> converted |
||
* @return expression seq in conjunctive normal form of input expression, if length exceeds | ||
* the threshold [[SQLConf.MAX_CNF_NODE_COUNT]] or length != 1, return empty Seq | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: This
|
||
*/ | ||
def conjunctiveNormalFormAndGroupExpsByQualifier(condition: Expression): Seq[Expression] = { | ||
cloud-fan marked this conversation as resolved.
Show resolved
Hide resolved
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. On second thought, the method name |
||
conjunctiveNormalForm(condition, | ||
(expressions: Seq[Expression]) => | ||
expressions.groupBy(_.references.map(_.qualifier)).map(_._2.reduceLeft(And)).toSeq) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit format:
|
||
} | ||
|
||
/** | ||
* Convert an expression to conjunctive normal form when pushing predicates for partition pruning. | ||
* When expanding predicates, we can group by the reference to avoid generating unnecessary | ||
* expressions and to control the length of the final result, since here we just have one table. In partition pruning | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: How about rephrasing it like this?
|
||
* strategies, we split filters by [[splitConjunctivePredicates]] and determine partition filters by | ||
* judging whether their references are a subset of partCols. If we combine expressions grouped by | ||
* reference when expanding predicates of [[Or]], it won't impact the final predicate pruning result since | ||
* [[splitConjunctivePredicates]] won't split [[Or]] expression. | ||
* @param condition the condition to be converted | ||
* @return expression seq in conjunctive normal form of input expression, if length exceeds | ||
* the threshold [[SQLConf.MAX_CNF_NODE_COUNT]] or length != 1, return empty Seq | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ditto |
||
*/ | ||
def conjunctiveNormalFormAndGroupExpsByReference(condition: Expression): Seq[Expression] = { | ||
cloud-fan marked this conversation as resolved.
Show resolved
Hide resolved
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How about changing to |
||
conjunctiveNormalForm(condition, | ||
(expressions: Seq[Expression]) => | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit format:
|
||
expressions.groupBy(e => AttributeSet(e.references)).map(_._2.reduceLeft(And)).toSeq) | ||
} | ||
|
||
wangyum marked this conversation as resolved.
Show resolved
Hide resolved
|
||
/** | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -39,7 +39,8 @@ import org.apache.spark.sql.types.StructType | |
* its underlying [[FileScan]]. And the partition filters will be removed in the filters of | ||
* returned logical plan. | ||
*/ | ||
private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] { | ||
private[sql] object PruneFileSourcePartitions | ||
extends Rule[LogicalPlan] with PredicateHelper { | ||
|
||
private def getPartitionKeyFiltersAndDataFilters( | ||
sparkSession: SparkSession, | ||
|
@@ -87,8 +88,12 @@ private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] { | |
_, | ||
_)) | ||
if filters.nonEmpty && fsRelation.partitionSchemaOption.isDefined => | ||
val predicates = conjunctiveNormalFormAndGroupExpsByReference(filters.reduceLeft(And)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we have to group by reference here? Can you explain the rationale a bit more? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I have discuss this problem in #28733 (comment) The demo in that comment can show why for partition pruning we need to use reference why not qualifier There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I still don't see the rationale. What if we don't do the group by and simply apply the CNF conversion? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Such a case
if we don't group by reference, the condition
but we know that only
In other word , since in final strategies of partition pruning, we partition predicate filter by judge if it's references is subset of partCols, if we combine condition group by reference, Here we result with It's ok to just simply apply CNF rule, but group by references can avoid generate unnecessary expression to control the length of generated final exprs. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, in brief:
|
||
val finalPredicates = if (predicates.nonEmpty) predicates else filters | ||
val (partitionKeyFilters, _) = getPartitionKeyFiltersAndDataFilters( | ||
fsRelation.sparkSession, logicalRelation, partitionSchema, filters, logicalRelation.output) | ||
fsRelation.sparkSession, logicalRelation, partitionSchema, finalPredicates, | ||
logicalRelation.output) | ||
|
||
if (partitionKeyFilters.nonEmpty) { | ||
val prunedFileIndex = catalogFileIndex.filterPartitions(partitionKeyFilters.toSeq) | ||
val prunedFsRelation = | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,22 +19,22 @@ package org.apache.spark.sql.hive.execution | |
|
||
import org.scalatest.Matchers._ | ||
|
||
import org.apache.spark.sql.QueryTest | ||
import org.apache.spark.sql.catalyst.TableIdentifier | ||
import org.apache.spark.sql.catalyst.dsl.expressions._ | ||
import org.apache.spark.sql.catalyst.dsl.plans._ | ||
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} | ||
import org.apache.spark.sql.catalyst.rules.RuleExecutor | ||
import org.apache.spark.sql.execution.FileSourceScanExec | ||
import org.apache.spark.sql.execution.datasources.{CatalogFileIndex, HadoopFsRelation, LogicalRelation, PruneFileSourcePartitions} | ||
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat | ||
import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec | ||
import org.apache.spark.sql.functions.broadcast | ||
import org.apache.spark.sql.hive.test.TestHiveSingleton | ||
import org.apache.spark.sql.internal.SQLConf | ||
import org.apache.spark.sql.test.SQLTestUtils | ||
import org.apache.spark.sql.types.StructType | ||
|
||
class PruneFileSourcePartitionsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { | ||
class PruneFileSourcePartitionsSuite extends PrunePartitionSuiteBase { | ||
|
||
override def format: String = "parquet" | ||
|
||
object Optimize extends RuleExecutor[LogicalPlan] { | ||
val batches = Batch("PruneFileSourcePartitions", Once, PruneFileSourcePartitions) :: Nil | ||
|
@@ -108,4 +108,10 @@ class PruneFileSourcePartitionsSuite extends QueryTest with SQLTestUtils with Te | |
} | ||
} | ||
} | ||
|
||
override def getScanExecPartitionSize(query: String): Long = { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: the input can be spark plan instead of query string. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Updated |
||
sql(query).queryExecution.sparkPlan.collectFirst { | ||
case p: FileSourceScanExec => p | ||
}.get.relation.location.inputFiles.length | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is not partitions number, but files number. This should be There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
make FileSourceScanExec's selectedPartitions to be public like HiveTableScanExec's prunedPartitions will be ok |
||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,22 +17,21 @@ | |
|
||
package org.apache.spark.sql.hive.execution | ||
|
||
import org.apache.spark.sql.QueryTest | ||
import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases | ||
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan | ||
import org.apache.spark.sql.catalyst.rules.RuleExecutor | ||
import org.apache.spark.sql.hive.test.TestHiveSingleton | ||
import org.apache.spark.sql.test.SQLTestUtils | ||
|
||
class PruneHiveTablePartitionsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { | ||
class PruneHiveTablePartitionsSuite extends PrunePartitionSuiteBase { | ||
|
||
override def format(): String = "hive" | ||
|
||
object Optimize extends RuleExecutor[LogicalPlan] { | ||
val batches = | ||
Batch("PruneHiveTablePartitions", Once, | ||
EliminateSubqueryAliases, new PruneHiveTablePartitions(spark)) :: Nil | ||
} | ||
|
||
test("SPARK-15616 statistics pruned after going throuhg PruneHiveTablePartitions") { | ||
test("SPARK-15616 statistics pruned after going through PruneHiveTablePartitions") { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: |
||
withTable("test", "temp") { | ||
sql( | ||
s""" | ||
|
@@ -54,4 +53,10 @@ class PruneHiveTablePartitionsSuite extends QueryTest with SQLTestUtils with Tes | |
Optimize.execute(analyzed2).stats.sizeInBytes) | ||
} | ||
} | ||
|
||
override def getScanExecPartitionSize(query: String): Long = { | ||
sql(query).queryExecution.sparkPlan.collectFirst { | ||
case p: HiveTableScanExec => p | ||
}.get.prunedPartitions.size | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.sql.hive.execution | ||
|
||
import org.apache.spark.sql.QueryTest | ||
import org.apache.spark.sql.hive.HiveUtils | ||
import org.apache.spark.sql.hive.test.TestHiveSingleton | ||
import org.apache.spark.sql.test.SQLTestUtils | ||
|
||
abstract class PrunePartitionSuiteBase extends QueryTest with SQLTestUtils with TestHiveSingleton { | ||
|
||
protected def format: String | ||
|
||
test("SPARK-28169: Convert scan predicate condition to CNF") { | ||
withTempView("temp") { | ||
withTable("t") { | ||
sql( | ||
s""" | ||
|CREATE TABLE t(i INT, p STRING) | ||
|USING $format | ||
|PARTITIONED BY (p)""".stripMargin) | ||
|
||
spark.range(0, 1000, 1).selectExpr("id as col") | ||
.createOrReplaceTempView("temp") | ||
cloud-fan marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
for (part <- Seq(1, 2, 3, 4)) { | ||
sql( | ||
s""" | ||
|INSERT OVERWRITE TABLE t PARTITION (p='$part') | ||
|SELECT col FROM temp""".stripMargin) | ||
} | ||
|
||
assertPrunedPartitions( | ||
"SELECT * FROM t WHERE p = '1' OR (p = '2' AND i = 1)", 2) | ||
assertPrunedPartitions( | ||
"SELECT * FROM t WHERE (p = '1' AND i = 2) OR (i = 1 OR p = '2')", 4) | ||
assertPrunedPartitions( | ||
"SELECT * FROM t WHERE (p = '1' AND i = 2) OR (p = '3' AND i = 3 )", 2) | ||
assertPrunedPartitions( | ||
"SELECT * FROM t WHERE (p = '1' AND i = 2) OR (p = '2' OR p = '3')", 3) | ||
assertPrunedPartitions( | ||
"SELECT * FROM t", 4) | ||
assertPrunedPartitions( | ||
"SELECT * FROM t WHERE p = '1' AND i = 2", 1) | ||
assertPrunedPartitions( | ||
""" | ||
|SELECT i, COUNT(1) FROM ( | ||
|SELECT * FROM t WHERE p = '1' OR (p = '2' AND i = 1) | ||
|) tmp GROUP BY i | ||
""".stripMargin, 2) | ||
} | ||
} | ||
} | ||
|
||
protected def assertPrunedPartitions(query: String, expected: Long): Unit = { | ||
assert(getScanExecPartitionSize(query) == expected) | ||
} | ||
|
||
protected def getScanExecPartitionSize(query: String): Long | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: do we really need a default value?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For current usage scenarios, no, maybe we shouldn't have a default to make sure each use case we have a definite purpose
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1 to remove this default.