Skip to content

Commit

Permalink
save
Browse files Browse the repository at this point in the history
  • Loading branch information
AngersZhuuuu committed Jun 12, 2020
1 parent 478a7a8 commit 603660b
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 64 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,9 @@ trait PredicateHelper extends Logging {
* @return the CNF result as sequence of disjunctive expressions. If the number of expressions
* exceeds threshold on converting `Or`, `Seq.empty` is returned.
*/
def conjunctiveNormalForm(condition: Expression): Seq[Expression] = {
def conjunctiveNormalForm(
condition: Expression,
groupExpsFunc: Seq[Expression] => Seq[Expression] = _.toSeq): Seq[Expression] = {
val postOrderNodes = postOrderTraversal(condition)
val resultStack = new mutable.Stack[Seq[Expression]]
val maxCnfNodeCount = SQLConf.get.maxCnfNodeCount
Expand All @@ -226,8 +228,8 @@ trait PredicateHelper extends Logging {
// For each side, there is no need to expand predicates of the same references.
// So here we can aggregate predicates of the same qualifier as one single predicate,
// for reducing the size of pushed down predicates and corresponding codegen.
val right = groupExpressionsByQualifier(resultStack.pop())
val left = groupExpressionsByQualifier(resultStack.pop())
val right = groupExpsFunc(resultStack.pop())
val left = groupExpsFunc(resultStack.pop())
// Stop the loop whenever the result exceeds the `maxCnfNodeCount`
if (left.size * right.size > maxCnfNodeCount) {
logInfo(s"As the result size exceeds the threshold $maxCnfNodeCount. " +
Expand All @@ -249,60 +251,16 @@ trait PredicateHelper extends Logging {
resultStack.top
}

private def groupExpressionsByQualifier(expressions: Seq[Expression]): Seq[Expression] = {
expressions.groupBy(_.references.map(_.qualifier)).map(_._2.reduceLeft(And)).toSeq
}

/**
* Convert an expression into conjunctive normal form for partition pruning.
* Definition and algorithm: https://en.wikipedia.org/wiki/Conjunctive_normal_form
* CNF can explode exponentially in the size of the input expression when converting [[Or]]
* clauses. Use a configuration [[SQLConf.MAX_CNF_NODE_COUNT]] to prevent such cases.
*
* @param condition to be converted into CNF.
* @return the CNF result as sequence of disjunctive expressions. If the number of expressions
* exceeds threshold on converting `Or`, `Seq.empty` is returned.
*/
def conjunctiveNormalFormForPartitionPruning(condition: Expression): Seq[Expression] = {
val postOrderNodes = postOrderTraversal(condition)
val resultStack = new mutable.Stack[Seq[Expression]]
val maxCnfNodeCount = SQLConf.get.maxCnfNodeCount
// Bottom up approach to get CNF of sub-expressions
while (postOrderNodes.nonEmpty) {
val cnf = postOrderNodes.pop() match {
case _: And =>
val right = resultStack.pop()
val left = resultStack.pop()
left ++ right
case _: Or =>
// For each side, there is no need to expand predicates of the same references.
// So here we can aggregate predicates of the same qualifier as one single predicate,
// for reducing the size of pushed down predicates and corresponding codegen.
val right = groupExpressionsByReference(resultStack.pop())
val left = groupExpressionsByReference(resultStack.pop())
// Stop the loop whenever the result exceeds the `maxCnfNodeCount`
if (left.size * right.size > maxCnfNodeCount) {
logInfo(s"As the result size exceeds the threshold $maxCnfNodeCount. " +
"The CNF conversion is skipped and returning Seq.empty now. To avoid this, you can " +
s"raise the limit ${SQLConf.MAX_CNF_NODE_COUNT.key}.")
return Seq.empty
} else {
for { x <- left; y <- right } yield Or(x, y)
}
case other => other :: Nil
}
resultStack.push(cnf)
}
if (resultStack.length != 1) {
logWarning("The length of CNF conversion result stack is supposed to be 1. There might " +
"be something wrong with CNF conversion.")
return Seq.empty
}
resultStack.top
def conjunctiveNormalFormAndGroupExpsByQualifier(condition: Expression): Seq[Expression] = {
conjunctiveNormalForm(condition,
(expressions: Seq[Expression]) =>
expressions.groupBy(_.references.map(_.qualifier)).map(_._2.reduceLeft(And)).toSeq)
}

private def groupExpressionsByReference(expressions: Seq[Expression]): Seq[Expression] = {
expressions.groupBy(_.references).map(_._2.reduceLeft(And)).toSeq
def conjunctiveNormalFormAndGroupExpsByReference(condition: Expression): Seq[Expression] = {
conjunctiveNormalForm(condition,
(expressions: Seq[Expression]) =>
expressions.groupBy(_.references).map(_._2.reduceLeft(And)).toSeq)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.rules.Rule
object PushCNFPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case j @ Join(left, right, joinType, Some(joinCondition), hint) =>
val predicates = conjunctiveNormalForm(joinCondition)
val predicates = conjunctiveNormalFormAndGroupExpsByQualifier(joinCondition)
if (predicates.isEmpty) {
j
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class ConjunctiveNormalFormPredicateSuite extends SparkFunSuite with PredicateHe

// Check CNF conversion with expected expression, assuming the input has non-empty result.
private def checkCondition(input: Expression, expected: Expression): Unit = {
val cnf = conjunctiveNormalForm(input)
val cnf = conjunctiveNormalFormAndGroupExpsByQualifier(input)
assert(cnf.nonEmpty)
val result = cnf.reduceLeft(And)
assert(result.semanticEquals(expected))
Expand Down Expand Up @@ -113,14 +113,14 @@ class ConjunctiveNormalFormPredicateSuite extends SparkFunSuite with PredicateHe
Seq(8, 9, 10, 35, 36, 37).foreach { maxCount =>
withSQLConf(SQLConf.MAX_CNF_NODE_COUNT.key -> maxCount.toString) {
if (maxCount < 36) {
assert(conjunctiveNormalForm(input).isEmpty)
assert(conjunctiveNormalFormAndGroupExpsByQualifier(input).isEmpty)
} else {
assert(conjunctiveNormalForm(input).nonEmpty)
assert(conjunctiveNormalFormAndGroupExpsByQualifier(input).nonEmpty)
}
if (maxCount < 9) {
assert(conjunctiveNormalForm(input2).isEmpty)
assert(conjunctiveNormalFormAndGroupExpsByQualifier(input2).isEmpty)
} else {
assert(conjunctiveNormalForm(input2).nonEmpty)
assert(conjunctiveNormalFormAndGroupExpsByQualifier(input2).nonEmpty)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ import org.apache.spark.sql.execution.datasources.LogicalRelation
*/
object PushCNFPredicateThroughFileScan extends Rule[LogicalPlan] with PredicateHelper {

def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case ScanOperation(projectList, conditions, relation: LogicalRelation)
if conditions.nonEmpty =>
val predicates = conjunctiveNormalFormForPartitionPruning(conditions.reduceLeft(And))
val predicates = conjunctiveNormalFormAndGroupExpsByReference(conditions.reduceLeft(And))
if (predicates.isEmpty) {
plan
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ object PushCNFPredicateThroughHiveTableScan extends Rule[LogicalPlan] with Predi
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case ScanOperation(projectList, conditions, relation: HiveTableRelation)
if conditions.nonEmpty =>
val predicates = conjunctiveNormalFormForPartitionPruning(conditions.reduceLeft(And))
val predicates = conjunctiveNormalFormAndGroupExpsByReference(conditions.reduceLeft(And))
if (predicates.isEmpty) {
plan
} else {
Expand Down

0 comments on commit 603660b

Please sign in to comment.