
[SPARK-28169][SQL] Convert scan predicate condition to CNF #28805

Closed
@@ -211,7 +211,9 @@ trait PredicateHelper extends Logging {
* @return the CNF result as sequence of disjunctive expressions. If the number of expressions
* exceeds threshold on converting `Or`, `Seq.empty` is returned.
*/
- protected def conjunctiveNormalForm(condition: Expression): Seq[Expression] = {
+ protected def conjunctiveNormalForm(
+     condition: Expression,
+     groupExpsFunc: Seq[Expression] => Seq[Expression]): Seq[Expression] = {
val postOrderNodes = postOrderTraversal(condition)
val resultStack = new mutable.Stack[Seq[Expression]]
val maxCnfNodeCount = SQLConf.get.maxCnfNodeCount
@@ -226,8 +228,8 @@ trait PredicateHelper extends Logging {
// For each side, there is no need to expand predicates of the same references.
// So here we can aggregate predicates of the same qualifier as one single predicate,
// for reducing the size of pushed down predicates and corresponding codegen.
- val right = groupExpressionsByQualifier(resultStack.pop())
- val left = groupExpressionsByQualifier(resultStack.pop())
+ val right = groupExpsFunc(resultStack.pop())
+ val left = groupExpsFunc(resultStack.pop())
// Stop the loop whenever the result exceeds the `maxCnfNodeCount`
if (left.size * right.size > maxCnfNodeCount) {
logInfo(s"As the result size exceeds the threshold $maxCnfNodeCount. " +
@@ -249,8 +251,36 @@ trait PredicateHelper extends Logging {
resultStack.top
}

- private def groupExpressionsByQualifier(expressions: Seq[Expression]): Seq[Expression] = {
- expressions.groupBy(_.references.map(_.qualifier)).map(_._2.reduceLeft(And)).toSeq
+ /**
+  * Convert an expression to conjunctive normal form when pushing predicates through Join.
+  * When expanding predicates, we can group them by qualifier to avoid generating
+  * unnecessary expressions and to control the length of the final result, since
+  * multiple tables are involved.
+  *
+  * @param condition condition that needs to be converted
+  * @return the CNF result as a sequence of disjunctive expressions. If the number of
+  *         expressions exceeds the threshold on converting `Or`, `Seq.empty` is returned.
+  */
+ def conjunctiveNormalFormAndGroupExpsByQualifier(condition: Expression): Seq[Expression] = {
Review comment (Member): On second thought, the method name conjunctiveNormalFormAndGroupExpsByQualifier is too long and the And is weird. How about changing to CNFWithGroupExpressionsByQualifier?

+ conjunctiveNormalForm(condition, (expressions: Seq[Expression]) =>
+   expressions.groupBy(_.references.map(_.qualifier)).map(_._2.reduceLeft(And)).toSeq)
Review comment (Member): nit format:

    conjunctiveNormalForm(condition, (expressions: Seq[Expression]) =>
        expressions.groupBy(_.references.map(_.qualifier)).map(_._2.reduceLeft(And)).toSeq)

+ }

+ /**
+  * Convert an expression to conjunctive normal form for predicate pushdown and partition
+  * pruning. When expanding predicates, this method groups expressions by their references
+  * to reduce the size of the pushed-down predicates and the corresponding codegen. In
+  * partition pruning strategies, we split filters with [[splitConjunctivePredicates]] and
+  * identify partition filters by checking whether their references are a subset of
+  * partCols. Combining expressions that share references when expanding an [[Or]]
+  * predicate does not affect the final pruning result, since
+  * [[splitConjunctivePredicates]] does not split [[Or]] expressions.
+  *
+  * @param condition condition that needs to be converted
+  * @return the CNF result as a sequence of disjunctive expressions. If the number of
+  *         expressions exceeds the threshold on converting `Or`, `Seq.empty` is returned.
+  */
+ def conjunctiveNormalFormAndGroupExpsByReference(condition: Expression): Seq[Expression] = {
Review comment (Member): How about changing to CNFWithGroupExpressionsByReference?

+ conjunctiveNormalForm(condition, (expressions: Seq[Expression]) =>
+   expressions.groupBy(e => AttributeSet(e.references)).map(_._2.reduceLeft(And)).toSeq)
+ }

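To make the two grouping strategies concrete, here is a small sketch (not part of the PR) that applies the two groupBy lambdas from the diff above to a few hand-built Catalyst predicates. It assumes spark-catalyst 3.x on the classpath; the object name GroupingDemo, the table aliases t1/t2, and the columns a, b, c are made up for illustration.

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.types.IntegerType

    object GroupingDemo extends App {
      // Two columns from table t1, one column from table t2.
      val a = AttributeReference("a", IntegerType)(qualifier = Seq("t1"))
      val b = AttributeReference("b", IntegerType)(qualifier = Seq("t1"))
      val c = AttributeReference("c", IntegerType)(qualifier = Seq("t2"))

      val predicates: Seq[Expression] = Seq(
        GreaterThan(a, Literal(1)), // references: {t1.a}
        LessThan(b, Literal(10)),   // references: {t1.b}
        EqualTo(c, Literal(5)))     // references: {t2.c}

      // Qualifier grouping (join pushdown): the two t1 predicates collapse into
      // one And, so each join side receives a single predicate per table.
      val byQualifier = predicates
        .groupBy(_.references.map(_.qualifier)).map(_._2.reduceLeft(And)).toSeq
      byQualifier.foreach(println) // 2 expressions

      // Reference grouping (scan pruning): a and b are distinct attribute sets,
      // so all three predicates remain separate conjuncts.
      val byReference = predicates
        .groupBy(e => AttributeSet(e.references)).map(_._2.reduceLeft(And)).toSeq
      byReference.foreach(println) // 3 expressions
    }

Grouping by qualifier is the coarser of the two: it is the right granularity when each group will be pushed to a different join side, while reference grouping only merges predicates that constrain exactly the same attributes.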
@@ -38,7 +38,7 @@ object PushCNFPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelpe
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case j @ Join(left, right, joinType, Some(joinCondition), hint)
if canPushThrough(joinType) =>
- val predicates = conjunctiveNormalForm(joinCondition)
+ val predicates = conjunctiveNormalFormAndGroupExpsByQualifier(joinCondition)
if (predicates.isEmpty) {
j
} else {
@@ -43,7 +43,7 @@ class ConjunctiveNormalFormPredicateSuite extends SparkFunSuite with PredicateHe

// Check CNF conversion with expected expression, assuming the input has non-empty result.
private def checkCondition(input: Expression, expected: Expression): Unit = {
- val cnf = conjunctiveNormalForm(input)
+ val cnf = conjunctiveNormalFormAndGroupExpsByQualifier(input)
assert(cnf.nonEmpty)
val result = cnf.reduceLeft(And)
assert(result.semanticEquals(expected))
@@ -113,14 +113,14 @@ class ConjunctiveNormalFormPredicateSuite extends SparkFunSuite with PredicateHe
Seq(8, 9, 10, 35, 36, 37).foreach { maxCount =>
withSQLConf(SQLConf.MAX_CNF_NODE_COUNT.key -> maxCount.toString) {
if (maxCount < 36) {
- assert(conjunctiveNormalForm(input).isEmpty)
+ assert(conjunctiveNormalFormAndGroupExpsByQualifier(input).isEmpty)
} else {
- assert(conjunctiveNormalForm(input).nonEmpty)
+ assert(conjunctiveNormalFormAndGroupExpsByQualifier(input).nonEmpty)
}
if (maxCount < 9) {
- assert(conjunctiveNormalForm(input2).isEmpty)
+ assert(conjunctiveNormalFormAndGroupExpsByQualifier(input2).isEmpty)
} else {
- assert(conjunctiveNormalForm(input2).nonEmpty)
+ assert(conjunctiveNormalFormAndGroupExpsByQualifier(input2).nonEmpty)
}
}
}
@@ -205,7 +205,7 @@ case class FileSourceScanExec(
private def isDynamicPruningFilter(e: Expression): Boolean =
e.find(_.isInstanceOf[PlanExpression[_]]).isDefined

- @transient private lazy val selectedPartitions: Array[PartitionDirectory] = {
+ @transient lazy val selectedPartitions: Array[PartitionDirectory] = {
val optimizerMetadataTimeNs = relation.location.metadataOpsTimeNs.getOrElse(0L)
val startTime = System.nanoTime()
val ret =
@@ -39,7 +39,8 @@ import org.apache.spark.sql.types.StructType
* its underlying [[FileScan]]. And the partition filters will be removed in the filters of
* returned logical plan.
*/
- private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] {
+ private[sql] object PruneFileSourcePartitions
+   extends Rule[LogicalPlan] with PredicateHelper {

private def getPartitionKeyFiltersAndDataFilters(
sparkSession: SparkSession,
@@ -87,8 +88,12 @@ private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] {
_,
_))
if filters.nonEmpty && fsRelation.partitionSchemaOption.isDefined =>
+ val predicates = conjunctiveNormalFormAndGroupExpsByReference(filters.reduceLeft(And))
Review comment (Contributor): Do we have to group by reference here? Can you explain the rationale a bit more?

Reply (Contributor, author): I discussed this problem in #28733 (comment). The demo in that comment shows why partition pruning needs to group by reference rather than by qualifier.

Review comment (Contributor): I still don't see the rationale. What if we don't do the group by and simply apply the CNF conversion?

Reply (Contributor, author): Consider a table test with partition column dt:

SELECT * FROM test WHERE (dt = 2 AND id < 100 AND id > 20) OR dt = 3

If we don't group by reference, the condition (dt = 2 AND id < 100 AND id > 20) OR dt = 3 is converted to

(dt = 3 OR dt = 2) AND (dt = 3 OR id < 100) AND (dt = 3 OR id > 20)

but only (dt = 3 OR dt = 2) can act as a partition-pruning predicate. By grouping by reference we can combine id < 100 and id > 20 and return

(dt = 3 OR dt = 2) AND (dt = 3 OR (id < 100 AND id > 20))

In other words, the final partition-pruning strategy selects partition filters by checking whether a predicate's references are a subset of the partition columns. Grouping by reference yields Or expressions, and splitConjunctivePredicates does not split Or, so the final pushdown result is unaffected.

It is OK to simply apply the CNF rule, but grouping by references avoids generating unnecessary expressions and controls the length of the final result.

Review comment (Member): Yeah, in brief:

  • if we are pushing predicates through Join, we can group by the qualifier, since there are multiple tables.
  • if we are pushing predicates down to the underlying file source of a single table, we should group by the reference.
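The back-and-forth above is easy to reproduce with a toy model. The sketch below is a self-contained stand-in, not Catalyst code; Expr, Pred, AndE, OrE and groupByRefs are all hypothetical names. It distributes Or over And with and without reference grouping, using the dt/id example from the discussion:

    object CnfGroupingDemo extends App {
      // Toy stand-ins for Catalyst's Expression, And and Or, each carrying a reference set.
      sealed trait Expr { def refs: Set[String] }
      case class Pred(sql: String, refs: Set[String]) extends Expr
      case class AndE(l: Expr, r: Expr) extends Expr { def refs = l.refs ++ r.refs }
      case class OrE(l: Expr, r: Expr) extends Expr { def refs = l.refs ++ r.refs }

      // Merge conjuncts that constrain the same references, mirroring
      // `expressions.groupBy(e => AttributeSet(e.references))` in the PR.
      def groupByRefs(es: Seq[Expr]): Seq[Expr] =
        es.groupBy(_.refs).values.map(_.reduceLeft(AndE)).toSeq

      // CNF by distributing Or over And; `group` toggles reference grouping.
      def toCnf(e: Expr, group: Boolean): Seq[Expr] = e match {
        case AndE(l, r) => toCnf(l, group) ++ toCnf(r, group)
        case OrE(l, r) =>
          val ls = if (group) groupByRefs(toCnf(l, group)) else toCnf(l, group)
          val rs = if (group) groupByRefs(toCnf(r, group)) else toCnf(r, group)
          for (x <- ls; y <- rs) yield OrE(x, y)
        case p => Seq(p)
      }

      // (dt = 2 AND id < 100 AND id > 20) OR dt = 3
      val cond = OrE(
        AndE(Pred("dt = 2", Set("dt")),
          AndE(Pred("id < 100", Set("id")), Pred("id > 20", Set("id")))),
        Pred("dt = 3", Set("dt")))

      // Without grouping: 3 conjuncts, 2 of them useless for partition pruning.
      assert(toCnf(cond, group = false).size == 3)
      // With grouping: (dt = 2 OR dt = 3) AND ((id < 100 AND id > 20) OR dt = 3).
      assert(toCnf(cond, group = true).size == 2)
    }

Only the first grouped conjunct references dt alone, so it is the one the pruning strategy keeps; the id bounds stay bundled in a single conjunct that splitConjunctivePredicates will not tear apart.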

+ val finalPredicates = if (predicates.nonEmpty) predicates else filters
val (partitionKeyFilters, _) = getPartitionKeyFiltersAndDataFilters(
- fsRelation.sparkSession, logicalRelation, partitionSchema, filters, logicalRelation.output)
+ fsRelation.sparkSession, logicalRelation, partitionSchema, finalPredicates,
+ logicalRelation.output)

if (partitionKeyFilters.nonEmpty) {
val prunedFileIndex = catalogFileIndex.filterPartitions(partitionKeyFilters.toSeq)
val prunedFsRelation =
@@ -21,8 +21,8 @@ import org.apache.hadoop.hive.common.StatsSetupConst

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.analysis.CastSupport
- import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTable, CatalogTablePartition, ExternalCatalogUtils, HiveTableRelation}
- import org.apache.spark.sql.catalyst.expressions.{And, AttributeSet, Expression, ExpressionSet, SubqueryExpression}
+ import org.apache.spark.sql.catalyst.catalog._
+ import org.apache.spark.sql.catalyst.expressions.{And, AttributeSet, Expression, ExpressionSet, PredicateHelper, SubqueryExpression}
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule
@@ -41,7 +41,7 @@ import org.apache.spark.sql.internal.SQLConf
* TODO: merge this with PruneFileSourcePartitions after we completely make hive as a data source.
*/
private[sql] class PruneHiveTablePartitions(session: SparkSession)
- extends Rule[LogicalPlan] with CastSupport {
+ extends Rule[LogicalPlan] with CastSupport with PredicateHelper {

override val conf: SQLConf = session.sessionState.conf

@@ -103,7 +103,9 @@ private[sql] class PruneHiveTablePartitions(session: SparkSession)
override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case op @ PhysicalOperation(projections, filters, relation: HiveTableRelation)
if filters.nonEmpty && relation.isPartitioned && relation.prunedPartitions.isEmpty =>
- val partitionKeyFilters = getPartitionKeyFilters(filters, relation)
+ val predicates = conjunctiveNormalFormAndGroupExpsByReference(filters.reduceLeft(And))
+ val finalPredicates = if (predicates.nonEmpty) predicates else filters
+ val partitionKeyFilters = getPartitionKeyFilters(finalPredicates, relation)
if (partitionKeyFilters.nonEmpty) {
val newPartitions = prunePartitions(relation, partitionKeyFilters)
val newTableMeta = updateTableMeta(relation.tableMeta, newPartitions)
@@ -19,22 +19,22 @@ package org.apache.spark.sql.hive.execution

import org.scalatest.Matchers._

import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.execution.{FileSourceScanExec, SparkPlan}
import org.apache.spark.sql.execution.datasources.{CatalogFileIndex, HadoopFsRelation, LogicalRelation, PruneFileSourcePartitions}
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
import org.apache.spark.sql.functions.broadcast
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.types.StructType

- class PruneFileSourcePartitionsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
+ class PruneFileSourcePartitionsSuite extends PrunePartitionSuiteBase {

+ override def format: String = "parquet"

object Optimize extends RuleExecutor[LogicalPlan] {
val batches = Batch("PruneFileSourcePartitions", Once, PruneFileSourcePartitions) :: Nil
@@ -108,4 +108,10 @@ class PruneFileSourcePartitionsSuite extends QueryTest with SQLTestUtils with Te
}
}
}

+ override def getScanExecPartitionSize(plan: SparkPlan): Long = {
+   plan.collectFirst {
+     case p: FileSourceScanExec => p
+   }.get.selectedPartitions.length
+ }
}
@@ -17,22 +17,22 @@

package org.apache.spark.sql.hive.execution

import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.execution.SparkPlan

- class PruneHiveTablePartitionsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
+ class PruneHiveTablePartitionsSuite extends PrunePartitionSuiteBase {

+ override def format(): String = "hive"

object Optimize extends RuleExecutor[LogicalPlan] {
val batches =
Batch("PruneHiveTablePartitions", Once,
EliminateSubqueryAliases, new PruneHiveTablePartitions(spark)) :: Nil
}

test("SPARK-15616 statistics pruned after going throuhg PruneHiveTablePartitions") {
test("SPARK-15616: statistics pruned after going through PruneHiveTablePartitions") {
withTable("test", "temp") {
sql(
s"""
@@ -54,4 +54,10 @@ class PruneHiveTablePartitionsSuite extends QueryTest with SQLTestUtils with Tes
Optimize.execute(analyzed2).stats.sizeInBytes)
}
}

+ override def getScanExecPartitionSize(plan: SparkPlan): Long = {
+   plan.collectFirst {
+     case p: HiveTableScanExec => p
+   }.get.prunedPartitions.size
+ }
}
@@ -0,0 +1,76 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.hive.execution

import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.test.SQLTestUtils

abstract class PrunePartitionSuiteBase extends QueryTest with SQLTestUtils with TestHiveSingleton {

protected def format: String

test("SPARK-28169: Convert scan predicate condition to CNF") {
withTempView("temp") {
withTable("t") {
sql(
s"""
|CREATE TABLE t(i INT, p STRING)
|USING $format
|PARTITIONED BY (p)""".stripMargin)

spark.range(0, 1000, 1).selectExpr("id as col")
.createOrReplaceTempView("temp")

for (part <- Seq(1, 2, 3, 4)) {
sql(
s"""
|INSERT OVERWRITE TABLE t PARTITION (p='$part')
|SELECT col FROM temp""".stripMargin)
}

assertPrunedPartitions(
"SELECT * FROM t WHERE p = '1' OR (p = '2' AND i = 1)", 2)
assertPrunedPartitions(
"SELECT * FROM t WHERE (p = '1' AND i = 2) OR (i = 1 OR p = '2')", 4)
assertPrunedPartitions(
"SELECT * FROM t WHERE (p = '1' AND i = 2) OR (p = '3' AND i = 3 )", 2)
assertPrunedPartitions(
"SELECT * FROM t WHERE (p = '1' AND i = 2) OR (p = '2' OR p = '3')", 3)
assertPrunedPartitions(
"SELECT * FROM t", 4)
assertPrunedPartitions(
"SELECT * FROM t WHERE p = '1' AND i = 2", 1)
assertPrunedPartitions(
"""
|SELECT i, COUNT(1) FROM (
|SELECT * FROM t WHERE p = '1' OR (p = '2' AND i = 1)
|) tmp GROUP BY i
""".stripMargin, 2)
}
}
}

protected def assertPrunedPartitions(query: String, expected: Long): Unit = {
val plan = sql(query).queryExecution.sparkPlan
assert(getScanExecPartitionSize(plan) == expected)
}

protected def getScanExecPartitionSize(plan: SparkPlan): Long
}