diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 527618b8e2c5a..aa5cf4758564b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -202,125 +202,50 @@ trait PredicateHelper extends Logging { } /** - * Convert an expression into conjunctive normal form. - * Definition and algorithm: https://en.wikipedia.org/wiki/Conjunctive_normal_form - * CNF can explode exponentially in the size of the input expression when converting [[Or]] - * clauses. Use a configuration [[SQLConf.MAX_CNF_NODE_COUNT]] to prevent such cases. - * - * @param condition to be converted into CNF. - * @return the CNF result as sequence of disjunctive expressions. If the number of expressions - * exceeds threshold on converting `Or`, `Seq.empty` is returned. + * Returns a filter that its reference is a subset of `outputSet` and it contains the maximum + * constraints from `condition`. This is used for predicate pushdown. + * When there is no such filter, `None` is returned. */ - protected def conjunctiveNormalForm( + protected def extractPredicatesWithinOutputSet( condition: Expression, - groupExpsFunc: Seq[Expression] => Seq[Expression]): Seq[Expression] = { - val postOrderNodes = postOrderTraversal(condition) - val resultStack = new mutable.Stack[Seq[Expression]] - val maxCnfNodeCount = SQLConf.get.maxCnfNodeCount - // Bottom up approach to get CNF of sub-expressions - while (postOrderNodes.nonEmpty) { - val cnf = postOrderNodes.pop() match { - case _: And => - val right = resultStack.pop() - val left = resultStack.pop() - left ++ right - case _: Or => - // For each side, there is no need to expand predicates of the same references. - // So here we can aggregate predicates of the same qualifier as one single predicate, - // for reducing the size of pushed down predicates and corresponding codegen. - val right = groupExpsFunc(resultStack.pop()) - val left = groupExpsFunc(resultStack.pop()) - // Stop the loop whenever the result exceeds the `maxCnfNodeCount` - if (left.size * right.size > maxCnfNodeCount) { - logInfo(s"As the result size exceeds the threshold $maxCnfNodeCount. " + - "The CNF conversion is skipped and returning Seq.empty now. To avoid this, you can " + - s"raise the limit ${SQLConf.MAX_CNF_NODE_COUNT.key}.") - return Seq.empty - } else { - for { x <- left; y <- right } yield Or(x, y) - } - case other => other :: Nil + outputSet: AttributeSet): Option[Expression] = condition match { + case And(left, right) => + val leftResultOptional = extractPredicatesWithinOutputSet(left, outputSet) + val rightResultOptional = extractPredicatesWithinOutputSet(right, outputSet) + (leftResultOptional, rightResultOptional) match { + case (Some(leftResult), Some(rightResult)) => Some(And(leftResult, rightResult)) + case (Some(leftResult), None) => Some(leftResult) + case (None, Some(rightResult)) => Some(rightResult) + case _ => None } - resultStack.push(cnf) - } - if (resultStack.length != 1) { - logWarning("The length of CNF conversion result stack is supposed to be 1. 
There might " + - "be something wrong with CNF conversion.") - return Seq.empty - } - resultStack.top - } - - /** - * Convert an expression to conjunctive normal form when pushing predicates through Join, - * when expand predicates, we can group by the qualifier avoiding generate unnecessary - * expression to control the length of final result since there are multiple tables. - * - * @param condition condition need to be converted - * @return the CNF result as sequence of disjunctive expressions. If the number of expressions - * exceeds threshold on converting `Or`, `Seq.empty` is returned. - */ - def CNFWithGroupExpressionsByQualifier(condition: Expression): Seq[Expression] = { - conjunctiveNormalForm(condition, (expressions: Seq[Expression]) => - expressions.groupBy(_.references.map(_.qualifier)).map(_._2.reduceLeft(And)).toSeq) - } - - /** - * Convert an expression to conjunctive normal form for predicate pushdown and partition pruning. - * When expanding predicates, this method groups expressions by their references for reducing - * the size of pushed down predicates and corresponding codegen. In partition pruning strategies, - * we split filters by [[splitConjunctivePredicates]] and partition filters by judging if it's - * references is subset of partCols, if we combine expressions group by reference when expand - * predicate of [[Or]], it won't impact final predicate pruning result since - * [[splitConjunctivePredicates]] won't split [[Or]] expression. - * - * @param condition condition need to be converted - * @return the CNF result as sequence of disjunctive expressions. If the number of expressions - * exceeds threshold on converting `Or`, `Seq.empty` is returned. - */ - def CNFWithGroupExpressionsByReference(condition: Expression): Seq[Expression] = { - conjunctiveNormalForm(condition, (expressions: Seq[Expression]) => - expressions.groupBy(e => AttributeSet(e.references)).map(_._2.reduceLeft(And)).toSeq) - } - /** - * Iterative post order traversal over a binary tree built by And/Or clauses with two stacks. - * For example, a condition `(a And b) Or c`, the postorder traversal is - * (`a`,`b`, `And`, `c`, `Or`). - * Following is the complete algorithm. After step 2, we get the postorder traversal in - * the second stack. - * 1. Push root to first stack. - * 2. Loop while first stack is not empty - * 2.1 Pop a node from first stack and push it to second stack - * 2.2 Push the children of the popped node to first stack - * - * @param condition to be traversed as binary tree - * @return sub-expressions in post order traversal as a stack. - * The first element of result stack is the leftmost node. - */ - private def postOrderTraversal(condition: Expression): mutable.Stack[Expression] = { - val stack = new mutable.Stack[Expression] - val result = new mutable.Stack[Expression] - stack.push(condition) - while (stack.nonEmpty) { - val node = stack.pop() - node match { - case Not(a And b) => stack.push(Or(Not(a), Not(b))) - case Not(a Or b) => stack.push(And(Not(a), Not(b))) - case Not(Not(a)) => stack.push(a) - case a And b => - result.push(node) - stack.push(a) - stack.push(b) - case a Or b => - result.push(node) - stack.push(a) - stack.push(b) - case _ => - result.push(node) + // The Or predicate is convertible when both of its children can be pushed down. + // That is to say, if one/both of the children can be partially pushed down, the Or + // predicate can be partially pushed down as well. + // + // Here is an example used to explain the reason. 
+ // Let's say we have + // condition: (a1 AND a2) OR (b1 AND b2), + // outputSet: AttributeSet(a1, b1) + // a1 and b1 are convertible, while a2 and b2 are not. + // The predicate can be converted to + // (a1 OR b1) AND (a1 OR b2) AND (a2 OR b1) AND (a2 OR b2) + // As per the logic of the And predicate, we can push down (a1 OR b1). + case Or(left, right) => + for { + lhs <- extractPredicatesWithinOutputSet(left, outputSet) + rhs <- extractPredicatesWithinOutputSet(right, outputSet) + } yield Or(lhs, rhs) + + // Here we assume all the `Not` operators are already below all the `And` and `Or` operators + // after the optimization rule `BooleanSimplification`, so that we don't need to handle the + // `Not` operators here. + case other => + if (other.references.subsetOf(outputSet)) { + Some(other) + } else { + None } - } - result } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index a0e21ed86a71e..79d00d32c9307 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -51,8 +51,7 @@ abstract class Optimizer(catalogManager: CatalogManager) override protected val excludedOnceBatches: Set[String] = Set( "PartitionPruning", - "Extract Python UDFs", - "Push CNF predicate through join") + "Extract Python UDFs") protected def fixedPoint = FixedPoint( @@ -123,8 +122,9 @@ abstract class Optimizer(catalogManager: CatalogManager) rulesWithoutInferFiltersFromConstraints: _*) :: // Set strategy to Once to avoid pushing filter every time because we do not change the // join condition. - Batch("Push CNF predicate through join", Once, - PushCNFPredicateThroughJoin) :: Nil + Batch("Push extra predicate through join", fixedPoint, + PushExtraPredicateThroughJoin, + PushDownPredicates) :: Nil } val batches = (Batch("Eliminate Distinct", Once, EliminateDistinct) :: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushCNFPredicateThroughJoin.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushExtraPredicateThroughJoin.scala similarity index 59% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushCNFPredicateThroughJoin.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushExtraPredicateThroughJoin.scala index 47e9527ead7c3..0ba2ce3106061 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushCNFPredicateThroughJoin.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PushExtraPredicateThroughJoin.scala @@ -17,18 +17,20 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.expressions.{And, PredicateHelper} +import org.apache.spark.sql.catalyst.expressions.{And, Expression, PredicateHelper} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{Filter, Join, LogicalPlan} import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.trees.TreeNodeTag /** - * Try converting join condition to conjunctive normal form expression so that more predicates may - * be able to be pushed down. + * Try pushing down parts of a disjunctive join condition into the left and right children.
* To avoid expanding the join condition, the join condition will be kept in the original form even * when predicate pushdown happens. */ -object PushCNFPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { +object PushExtraPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { + + private val processedJoinConditionTag = TreeNodeTag[Expression]("processedJoinCondition") private def canPushThrough(joinType: JoinType): Boolean = joinType match { case _: InnerLike | LeftSemi | RightOuter | LeftOuter | LeftAnti | ExistenceJoin(_) => true @@ -38,22 +40,28 @@ object PushCNFPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelpe def apply(plan: LogicalPlan): LogicalPlan = plan transform { case j @ Join(left, right, joinType, Some(joinCondition), hint) if canPushThrough(joinType) => - val predicates = CNFWithGroupExpressionsByQualifier(joinCondition) - if (predicates.isEmpty) { + val alreadyProcessed = j.getTagValue(processedJoinConditionTag).exists { condition => + condition.semanticEquals(joinCondition) + } + + lazy val filtersOfBothSide = splitConjunctivePredicates(joinCondition).filter { f => + f.deterministic && f.references.nonEmpty && + !f.references.subsetOf(left.outputSet) && !f.references.subsetOf(right.outputSet) + } + lazy val leftExtraCondition = + filtersOfBothSide.flatMap(extractPredicatesWithinOutputSet(_, left.outputSet)) + lazy val rightExtraCondition = + filtersOfBothSide.flatMap(extractPredicatesWithinOutputSet(_, right.outputSet)) + + if (alreadyProcessed || (leftExtraCondition.isEmpty && rightExtraCondition.isEmpty)) { j } else { - val pushDownCandidates = predicates.filter(_.deterministic) - lazy val leftFilterConditions = - pushDownCandidates.filter(_.references.subsetOf(left.outputSet)) - lazy val rightFilterConditions = - pushDownCandidates.filter(_.references.subsetOf(right.outputSet)) - lazy val newLeft = - leftFilterConditions.reduceLeftOption(And).map(Filter(_, left)).getOrElse(left) + leftExtraCondition.reduceLeftOption(And).map(Filter(_, left)).getOrElse(left) lazy val newRight = - rightFilterConditions.reduceLeftOption(And).map(Filter(_, right)).getOrElse(right) + rightExtraCondition.reduceLeftOption(And).map(Filter(_, right)).getOrElse(right) - joinType match { + val newJoin = joinType match { case _: InnerLike | LeftSemi => Join(newLeft, newRight, joinType, Some(joinCondition), hint) case RightOuter => @@ -63,6 +71,8 @@ object PushCNFPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelpe case other => throw new IllegalStateException(s"Unexpected join type: $other") } - } + newJoin.setTagValue(processedJoinConditionTag, joinCondition) + newJoin + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index e95ef3d77d549..4a1299d82f7bc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -545,19 +545,6 @@ object SQLConf { .booleanConf .createWithDefault(true) - val MAX_CNF_NODE_COUNT = - buildConf("spark.sql.optimizer.maxCNFNodeCount") - .internal() - .doc("Specifies the maximum allowable number of conjuncts in the result of CNF " + - "conversion. If the conversion exceeds the threshold, an empty sequence is returned. 
" + - "For example, CNF conversion of (a && b) || (c && d) generates " + - "four conjuncts (a || c) && (a || d) && (b || c) && (b || d).") - .version("3.1.0") - .intConf - .checkValue(_ >= 0, - "The depth of the maximum rewriting conjunction normal form must be positive.") - .createWithDefault(128) - val ESCAPED_STRING_LITERALS = buildConf("spark.sql.parser.escapedStringLiterals") .internal() .doc("When true, string literals (including regex patterns) remain escaped in our SQL " + @@ -2954,8 +2941,6 @@ class SQLConf extends Serializable with Logging { def constraintPropagationEnabled: Boolean = getConf(CONSTRAINT_PROPAGATION_ENABLED) - def maxCnfNodeCount: Int = getConf(MAX_CNF_NODE_COUNT) - def escapedStringLiterals: Boolean = getConf(ESCAPED_STRING_LITERALS) def fileCompressionFactor: Double = getConf(FILE_COMPRESSION_FACTOR) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConjunctiveNormalFormPredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConjunctiveNormalFormPredicateSuite.scala deleted file mode 100644 index 793abccd79405..0000000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConjunctiveNormalFormPredicateSuite.scala +++ /dev/null @@ -1,128 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.catalyst.expressions - -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.BooleanType - -class ConjunctiveNormalFormPredicateSuite extends SparkFunSuite with PredicateHelper with PlanTest { - private val a = AttributeReference("A", BooleanType)(exprId = ExprId(1)).withQualifier(Seq("ta")) - private val b = AttributeReference("B", BooleanType)(exprId = ExprId(2)).withQualifier(Seq("tb")) - private val c = AttributeReference("C", BooleanType)(exprId = ExprId(3)).withQualifier(Seq("tc")) - private val d = AttributeReference("D", BooleanType)(exprId = ExprId(4)).withQualifier(Seq("td")) - private val e = AttributeReference("E", BooleanType)(exprId = ExprId(5)).withQualifier(Seq("te")) - private val f = AttributeReference("F", BooleanType)(exprId = ExprId(6)).withQualifier(Seq("tf")) - private val g = AttributeReference("G", BooleanType)(exprId = ExprId(7)).withQualifier(Seq("tg")) - private val h = AttributeReference("H", BooleanType)(exprId = ExprId(8)).withQualifier(Seq("th")) - private val i = AttributeReference("I", BooleanType)(exprId = ExprId(9)).withQualifier(Seq("ti")) - private val j = AttributeReference("J", BooleanType)(exprId = ExprId(10)).withQualifier(Seq("tj")) - private val a1 = - AttributeReference("a1", BooleanType)(exprId = ExprId(11)).withQualifier(Seq("ta")) - private val a2 = - AttributeReference("a2", BooleanType)(exprId = ExprId(12)).withQualifier(Seq("ta")) - private val b1 = - AttributeReference("b1", BooleanType)(exprId = ExprId(12)).withQualifier(Seq("tb")) - - // Check CNF conversion with expected expression, assuming the input has non-empty result. 
- private def checkCondition(input: Expression, expected: Expression): Unit = { - val cnf = CNFWithGroupExpressionsByQualifier(input) - assert(cnf.nonEmpty) - val result = cnf.reduceLeft(And) - assert(result.semanticEquals(expected)) - } - - test("Keep non-predicated expressions") { - checkCondition(a, a) - checkCondition(Literal(1), Literal(1)) - } - - test("Conversion of Not") { - checkCondition(!a, !a) - checkCondition(!(!a), a) - checkCondition(!(!(a && b)), a && b) - checkCondition(!(!(a || b)), a || b) - checkCondition(!(a || b), !a && !b) - checkCondition(!(a && b), !a || !b) - } - - test("Conversion of And") { - checkCondition(a && b, a && b) - checkCondition(a && b && c, a && b && c) - checkCondition(a && (b || c), a && (b || c)) - checkCondition((a || b) && c, (a || b) && c) - checkCondition(a && b && c && d, a && b && c && d) - } - - test("Conversion of Or") { - checkCondition(a || b, a || b) - checkCondition(a || b || c, a || b || c) - checkCondition(a || b || c || d, a || b || c || d) - checkCondition((a && b) || c, (a || c) && (b || c)) - checkCondition((a && b) || (c && d), (a || c) && (a || d) && (b || c) && (b || d)) - } - - test("More complex cases") { - checkCondition(a && !(b || c), a && !b && !c) - checkCondition((a && b) || !(c && d), (a || !c || !d) && (b || !c || !d)) - checkCondition(a || b || c && d, (a || b || c) && (a || b || d)) - checkCondition(a || (b && c || d), (a || b || d) && (a || c || d)) - checkCondition(a && !(b && c || d && e), a && (!b || !c) && (!d || !e)) - checkCondition(((a && b) || c) || (d || e), (a || c || d || e) && (b || c || d || e)) - - checkCondition( - (a && b && c) || (d && e && f), - (a || d) && (a || e) && (a || f) && (b || d) && (b || e) && (b || f) && - (c || d) && (c || e) && (c || f) - ) - } - - test("Aggregate predicate of same qualifiers to avoid expanding") { - checkCondition(((a && b && a1) || c), ((a && a1) || c) && (b ||c)) - checkCondition(((a && a1 && b) || c), ((a && a1) || c) && (b ||c)) - checkCondition(((b && d && a && a1) || c), ((a && a1) || c) && (b ||c) && (d || c)) - checkCondition(((b && a2 && d && a && a1) || c), ((a2 && a && a1) || c) && (b ||c) && (d || c)) - checkCondition(((b && d && a && a1 && b1) || c), - ((a && a1) || c) && ((b && b1) ||c) && (d || c)) - checkCondition((a && a1) || (b && b1), (a && a1) || (b && b1)) - checkCondition((a && a1 && c) || (b && b1), ((a && a1) || (b && b1)) && (c || (b && b1))) - } - - test("Return Seq.empty when exceeding MAX_CNF_NODE_COUNT") { - // The following expression contains 36 conjunctive sub-expressions in CNF - val input = (a && b && c) || (d && e && f) || (g && h && i && j) - // The following expression contains 9 conjunctive sub-expressions in CNF - val input2 = (a && b && c) || (d && e && f) - Seq(8, 9, 10, 35, 36, 37).foreach { maxCount => - withSQLConf(SQLConf.MAX_CNF_NODE_COUNT.key -> maxCount.toString) { - if (maxCount < 36) { - assert(CNFWithGroupExpressionsByQualifier(input).isEmpty) - } else { - assert(CNFWithGroupExpressionsByQualifier(input).nonEmpty) - } - if (maxCount < 9) { - assert(CNFWithGroupExpressionsByQualifier(input2).isEmpty) - } else { - assert(CNFWithGroupExpressionsByQualifier(input2).nonEmpty) - } - } - } - } -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExtractPredicatesWithinOutputSetSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExtractPredicatesWithinOutputSetSuite.scala new file mode 100644 index 0000000000000..ed141ef923e0a --- /dev/null +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExtractPredicatesWithinOutputSetSuite.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.types.BooleanType + +class ExtractPredicatesWithinOutputSetSuite + extends SparkFunSuite + with PredicateHelper + with PlanTest { + private val a = AttributeReference("A", BooleanType)(exprId = ExprId(1)) + private val b = AttributeReference("B", BooleanType)(exprId = ExprId(2)) + private val c = AttributeReference("C", BooleanType)(exprId = ExprId(3)) + private val d = AttributeReference("D", BooleanType)(exprId = ExprId(4)) + private val e = AttributeReference("E", BooleanType)(exprId = ExprId(5)) + private val f = AttributeReference("F", BooleanType)(exprId = ExprId(6)) + private val g = AttributeReference("G", BooleanType)(exprId = ExprId(7)) + private val h = AttributeReference("H", BooleanType)(exprId = ExprId(8)) + private val i = AttributeReference("I", BooleanType)(exprId = ExprId(9)) + + private def checkCondition( + input: Expression, + convertibleAttributes: Seq[Attribute], + expected: Option[Expression]): Unit = { + val result = extractPredicatesWithinOutputSet(input, AttributeSet(convertibleAttributes)) + if (expected.isEmpty) { + assert(result.isEmpty) + } else { + assert(result.isDefined && result.get.semanticEquals(expected.get)) + } + } + + test("Convertible conjunctive predicates") { + checkCondition(a && b, Seq(a, b), Some(a && b)) + checkCondition(a && b, Seq(a), Some(a)) + checkCondition(a && b, Seq(b), Some(b)) + checkCondition(a && b && c, Seq(a, c), Some(a && c)) + checkCondition(a && b && c && d, Seq(b, c), Some(b && c)) + } + + test("Convertible disjunctive predicates") { + checkCondition(a || b, Seq(a, b), Some(a || b)) + checkCondition(a || b, Seq(a), None) + checkCondition(a || b, Seq(b), None) + checkCondition(a || b || c, Seq(a, c), None) + checkCondition(a || b || c || d, Seq(a, b, d), None) + checkCondition(a || b || c || d, Seq(d, c, b, a), Some(a || b || c || d)) + } + + test("Convertible complex predicates") { + checkCondition((a && b) || (c && d), Seq(a, c), Some(a || c)) + checkCondition((a && b) || (c && d), Seq(a, b), None) + checkCondition((a && b) || (c && d), Seq(a, c, d), Some(a || (c && d))) + checkCondition((a && b && c) || (d && e && f), Seq(a, c, d, f), Some((a && c) || (d && f))) + checkCondition((a && b) || (c && d) || (e && f) || (g && h), Seq(a, c, e, g), + Some(a || c || e || g)) + checkCondition((a && b) || (c && d) || (e && f) || (g && h), Seq(a, e, g), None) + checkCondition((a 
|| b) || (c && d) || (e && f) || (g && h), Seq(a, c, e, g), None) + checkCondition((a || b) || (c && d) || (e && f) || (g && h), Seq(a, b, c, e, g), + Some(a || b || c || e || g)) + checkCondition((a && b && c) || (d && e && f) || (g && h && i), Seq(b, e, h), Some(b || e || h)) + checkCondition((a && b && c) || (d && e && f) || (g && h && i), Seq(b, e, d), None) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index bb7e9d04c12d9..cf92e25ccab48 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -33,9 +33,6 @@ class FilterPushdownSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { - override protected val excludedOnceBatches: Set[String] = - Set("Push CNF predicate through join") - val batches = Batch("Subqueries", Once, EliminateSubqueryAliases) :: @@ -45,8 +42,9 @@ class FilterPushdownSuite extends PlanTest { BooleanSimplification, PushPredicateThroughJoin, CollapseProject) :: - Batch("Push CNF predicate through join", Once, - PushCNFPredicateThroughJoin) :: Nil + Batch("Push extra predicate through join", FixedPoint(10), + PushExtraPredicateThroughJoin, + PushDownPredicates) :: Nil } val attrA = 'a.int @@ -60,7 +58,7 @@ class FilterPushdownSuite extends PlanTest { val simpleDisjunctivePredicate = ("x.a".attr > 3) && ("y.a".attr > 13) || ("x.a".attr > 1) && ("y.a".attr > 11) - val expectedCNFPredicatePushDownResult = { + val expectedPredicatePushDownResult = { val left = testRelation.where(('a > 3 || 'a > 1)).subquery('x) val right = testRelation.where('a > 13 || 'a > 11).subquery('y) left.join(right, condition = Some("x.b".attr === "y.b".attr @@ -1247,17 +1245,17 @@ class FilterPushdownSuite extends PlanTest { comparePlans(Optimize.execute(query.analyze), expected) } - test("inner join: rewrite filter predicates to conjunctive normal form") { + test("push down filter predicates through inner join") { val x = testRelation.subquery('x) val y = testRelation.subquery('y) val originalQuery = x.join(y).where(("x.b".attr === "y.b".attr) && (simpleDisjunctivePredicate)) val optimized = Optimize.execute(originalQuery.analyze) - comparePlans(optimized, expectedCNFPredicatePushDownResult) + comparePlans(optimized, expectedPredicatePushDownResult) } - test("inner join: rewrite join predicates to conjunctive normal form") { + test("push down join predicates through inner join") { val x = testRelation.subquery('x) val y = testRelation.subquery('y) @@ -1265,10 +1263,10 @@ class FilterPushdownSuite extends PlanTest { x.join(y, condition = Some(("x.b".attr === "y.b".attr) && (simpleDisjunctivePredicate))) val optimized = Optimize.execute(originalQuery.analyze) - comparePlans(optimized, expectedCNFPredicatePushDownResult) + comparePlans(optimized, expectedPredicatePushDownResult) } - test("inner join: rewrite complex join predicates to conjunctive normal form") { + test("push down complex predicates through inner join") { val x = testRelation.subquery('x) val y = testRelation.subquery('y) @@ -1288,7 +1286,7 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - test("inner join: rewrite join predicates(with NOT predicate) to conjunctive normal form") { + test("push down predicates(with NOT predicate) through inner join") { val x = 
testRelation.subquery('x) val y = testRelation.subquery('y) @@ -1308,7 +1306,7 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - test("left join: rewrite join predicates to conjunctive normal form") { + test("push down predicates through left join") { val x = testRelation.subquery('x) val y = testRelation.subquery('y) @@ -1327,7 +1325,7 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - test("right join: rewrite join predicates to conjunctive normal form") { + test("push down predicates through right join") { val x = testRelation.subquery('x) val y = testRelation.subquery('y) @@ -1346,7 +1344,7 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - test("inner join: rewrite to conjunctive normal form avoid generating too many predicates") { + test("SPARK-32302: avoid generating too many predicates") { val x = testRelation.subquery('x) val y = testRelation.subquery('y) @@ -1364,30 +1362,20 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - test(s"Disable rewrite to CNF by setting ${SQLConf.MAX_CNF_NODE_COUNT.key}=0") { + test("push down predicate through multiple joins") { val x = testRelation.subquery('x) val y = testRelation.subquery('y) + val z = testRelation.subquery('z) + val xJoinY = x.join(y, condition = Some("x.b".attr === "y.b".attr)) + val originalQuery = z.join(xJoinY, + condition = Some("x.a".attr === "z.a".attr && simpleDisjunctivePredicate)) - val originalQuery = - x.join(y, condition = Some(("x.b".attr === "y.b".attr) - && ((("x.a".attr > 3) && ("x.a".attr < 13) && ("y.c".attr <= 5)) - || (("y.a".attr > 2) && ("y.c".attr < 1))))) - - Seq(0, 10).foreach { count => - withSQLConf(SQLConf.MAX_CNF_NODE_COUNT.key -> count.toString) { - val optimized = Optimize.execute(originalQuery.analyze) - val (left, right) = if (count == 0) { - (testRelation.subquery('x), testRelation.subquery('y)) - } else { - (testRelation.subquery('x), - testRelation.where('c <= 5 || ('a > 2 && 'c < 1)).subquery('y)) - } - val correctAnswer = left.join(right, condition = Some("x.b".attr === "y.b".attr - && ((("x.a".attr > 3) && ("x.a".attr < 13) && ("y.c".attr <= 5)) - || (("y.a".attr > 2) && ("y.c".attr < 1))))).analyze - - comparePlans(optimized, correctAnswer) - } - } + val optimized = Optimize.execute(originalQuery.analyze) + val left = x.where('a > 3 || 'a > 1) + val right = y.where('a > 13 || 'a > 11) + val correctAnswer = z.join(left.join(right, + condition = Some("x.b".attr === "y.b".attr && simpleDisjunctivePredicate)), + condition = Some("x.a".attr === "z.a".attr)).analyze + comparePlans(optimized, correctAnswer) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala index 576a826faf894..0c56e7675da6f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala @@ -56,8 +56,10 @@ private[sql] object PruneFileSourcePartitions val (partitionFilters, dataFilters) = normalizedFilters.partition(f => f.references.subsetOf(partitionSet) ) + val extraPartitionFilter = + dataFilters.flatMap(extractPredicatesWithinOutputSet(_, partitionSet)) - (ExpressionSet(partitionFilters), dataFilters) + (ExpressionSet(partitionFilters ++ 
extraPartitionFilter), dataFilters) } private def rebuildPhysicalOperation( @@ -88,10 +90,8 @@ private[sql] object PruneFileSourcePartitions _, _)) if filters.nonEmpty && fsRelation.partitionSchemaOption.isDefined => - val predicates = CNFWithGroupExpressionsByReference(filters.reduceLeft(And)) - val finalPredicates = if (predicates.nonEmpty) predicates else filters val (partitionKeyFilters, _) = getPartitionKeyFiltersAndDataFilters( - fsRelation.sparkSession, logicalRelation, partitionSchema, finalPredicates, + fsRelation.sparkSession, logicalRelation, partitionSchema, filters, logicalRelation.output) if (partitionKeyFilters.nonEmpty) { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitions.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitions.scala index c4885f2842597..f6aff10cbc147 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitions.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitions.scala @@ -54,9 +54,8 @@ private[sql] class PruneHiveTablePartitions(session: SparkSession) val normalizedFilters = DataSourceStrategy.normalizeExprs( filters.filter(f => f.deterministic && !SubqueryExpression.hasSubquery(f)), relation.output) val partitionColumnSet = AttributeSet(relation.partitionCols) - ExpressionSet(normalizedFilters.filter { f => - !f.references.isEmpty && f.references.subsetOf(partitionColumnSet) - }) + ExpressionSet( + normalizedFilters.flatMap(extractPredicatesWithinOutputSet(_, partitionColumnSet))) } /** @@ -103,9 +102,7 @@ private[sql] class PruneHiveTablePartitions(session: SparkSession) override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case op @ PhysicalOperation(projections, filters, relation: HiveTableRelation) if filters.nonEmpty && relation.isPartitioned && relation.prunedPartitions.isEmpty => - val predicates = CNFWithGroupExpressionsByReference(filters.reduceLeft(And)) - val finalPredicates = if (predicates.nonEmpty) predicates else filters - val partitionKeyFilters = getPartitionKeyFilters(finalPredicates, relation) + val partitionKeyFilters = getPartitionKeyFilters(filters, relation) if (partitionKeyFilters.nonEmpty) { val newPartitions = prunePartitions(relation, partitionKeyFilters) val newTableMeta = updateTableMeta(relation.tableMeta, newPartitions) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitionsSuite.scala index c29e889c3a941..06aea084330fa 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruneHiveTablePartitionsSuite.scala @@ -55,6 +55,31 @@ class PruneHiveTablePartitionsSuite extends PrunePartitionSuiteBase { } } + test("Avoid generating too many predicates in partition pruning") { + withTempView("temp") { + withTable("t") { + sql( + s""" + |CREATE TABLE t(i INT, p0 INT, p1 INT) + |USING $format + |PARTITIONED BY (p0, p1)""".stripMargin) + + spark.range(0, 10, 1).selectExpr("id as col") + .createOrReplaceTempView("temp") + + for (part <- (0 to 25)) { + sql( + s""" + |INSERT OVERWRITE TABLE t PARTITION (p0='$part', p1='$part') + |SELECT col FROM temp""".stripMargin) + } + val scale = 20 + val predicate = (1 to scale).map(i => s"(p0 = '$i' AND p1 = '$i')").mkString(" OR ") + 
assertPrunedPartitions(s"SELECT * FROM t WHERE $predicate", scale) + } + } + } + override def getScanExecPartitionSize(plan: SparkPlan): Long = { plan.collectFirst { case p: HiveTableScanExec => p
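Editorial note (illustrative sketch, not part of the patch): the heart of this change is `extractPredicatesWithinOutputSet`, which keeps the maximal part of a condition whose references fall inside a given attribute set. An `And` keeps whichever side survives, while an `Or` survives only if both of its sides do. The following minimal, self-contained Scala sketch mirrors that logic with toy `Expr`/`Attr`/`And`/`Or` types (hypothetical names, not Catalyst's classes) and reproduces the a1/a2/b1/b2 example from the `PredicateHelper` comment above.

object ExtractWithinOutputSetSketch extends App {
  // Toy expression tree; `refs` plays the role of Expression.references.
  sealed trait Expr { def refs: Set[String] }
  case class Attr(name: String) extends Expr { def refs: Set[String] = Set(name) }
  case class And(left: Expr, right: Expr) extends Expr { def refs: Set[String] = left.refs ++ right.refs }
  case class Or(left: Expr, right: Expr) extends Expr { def refs: Set[String] = left.refs ++ right.refs }

  // Keep the maximal sub-predicate whose references fall within `outputSet`.
  // An And keeps whichever side survives; an Or survives only if both sides do.
  def extractWithin(cond: Expr, outputSet: Set[String]): Option[Expr] = cond match {
    case And(l, r) =>
      (extractWithin(l, outputSet), extractWithin(r, outputSet)) match {
        case (Some(a), Some(b)) => Some(And(a, b))
        case (Some(a), None) => Some(a)
        case (None, Some(b)) => Some(b)
        case _ => None
      }
    case Or(l, r) =>
      for {
        a <- extractWithin(l, outputSet)
        b <- extractWithin(r, outputSet)
      } yield Or(a, b)
    case leaf =>
      if (leaf.refs.subsetOf(outputSet)) Some(leaf) else None
  }

  // (a1 AND a2) OR (b1 AND b2), restricted to {a1, b1}, yields a1 OR b1,
  // matching the worked example in the PredicateHelper comment.
  val cond = Or(And(Attr("a1"), Attr("a2")), And(Attr("b1"), Attr("b2")))
  assert(extractWithin(cond, Set("a1", "b1")).contains(Or(Attr("a1"), Attr("b1"))))
  println(extractWithin(cond, Set("a1", "b1"))) // Some(Or(Attr(a1),Attr(b1)))
}

In the patch itself the same extraction is applied in two places: PushExtraPredicateThroughJoin runs it against each join child's outputSet to derive extra Filter conditions, while PruneFileSourcePartitions and PruneHiveTablePartitions run it against the partition column set to recover partition filters from disjunctive predicates.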