Skip to content

Commit

Permalink
fix test case; reduce threshold default value
Browse files Browse the repository at this point in the history
  • Loading branch information
gengliangwang committed Jun 6, 2020
1 parent a9a5c0b commit cbc1220
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,11 @@ object PushCNFPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelpe
val left: Seq[Expression] = resultStack.pop()
left ++ right
case _: Or =>
val right: Seq[Expression] = resultStack.pop()
val left: Seq[Expression] = resultStack.pop()
// For each side, predicates with the same references need not be expanded separately,
// so aggregate them into one single predicate. This reduces the size of the pushed-down
// predicates and of the corresponding codegen.
val right = aggregateExpressionsOfSameReference(resultStack.pop())
val left = aggregateExpressionsOfSameReference(resultStack.pop())
// Stop the loop whenever the result exceeds the `maxCnfNodeCount`
if (left.size * right.size > maxCnfNodeCount) {
Seq.empty
Expand All @@ -75,6 +78,9 @@ object PushCNFPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelpe
resultStack.top
}

/**
 * Combines predicates so that all predicates whose referenced attributes carry the same
 * qualifiers are joined with `And` into one predicate.
 *
 * @param expressions predicates to be combined
 * @return one conjunction per distinct group of qualifiers
 */
private def aggregateExpressionsOfSameReference(expressions: Seq[Expression]): Seq[Expression] = {
  // The grouping key is the collection of qualifiers of the attributes a predicate references.
  val groupedByQualifiers = expressions.groupBy { predicate =>
    predicate.references.map(attribute => attribute.qualifier)
  }
  // `groupBy` never produces an empty group, so `reduceLeft` always has an element to start from.
  groupedByQualifiers.values.map(group => group.reduceLeft(And)).toSeq
}
/**
* Iterative post order traversal over a binary tree built by And/Or clauses.
* @param condition to be traversed as binary tree
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -555,7 +555,7 @@ object SQLConf {
.intConf
.checkValue(_ >= 0,
"The depth of the maximum rewriting conjunction normal form must be positive.")
.createWithDefault(256)
.createWithDefault(128)

val ESCAPED_STRING_LITERALS = buildConf("spark.sql.parser.escapedStringLiterals")
.internal()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1293,14 +1293,7 @@ class FilterPushdownSuite extends PlanTest {
val left = testRelation.where(
('a === 5 || 'a === 2 || 'a === 1)).subquery('x)
val right = testRelation.where(
('a >= 2 || 'a >= 1 || 'a >= 9) &&
('a >= 2 || 'a >= 1 || 'a <= 27) &&
('a >= 2 || 'a <=14 || 'a >= 9) &&
('a >= 2 || 'a <=14 || 'a <= 27) &&
('a <= 3 || 'a >= 1 || 'a >= 9) &&
('a <= 3 || 'a >= 1 || 'a <= 27) &&
('a <= 3 || 'a <=14 || 'a >= 9) &&
('a <= 3 || 'a <=14 || 'a <= 27)).subquery('y)
('a >= 2 && 'a <= 3) || ('a >= 1 && 'a <= 14) || ('a >= 9 && 'a <= 27)).subquery('y)
val correctAnswer = left.join(right, condition = Some(joinCondition)).analyze

comparePlans(optimized, correctAnswer)
Expand Down Expand Up @@ -1367,6 +1360,25 @@ class FilterPushdownSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}

test("inner join: rewrite to conjunctive normal form avoid generating too many predicates") {
  val x = testRelation.subquery('x)
  val y = testRelation.subquery('y)

  // Join on x.b = y.b, plus a disjunction whose first branch mixes predicates on x and y
  // and whose second branch only touches y.
  val joinKey = "x.b".attr === "y.b".attr
  val mixedBranch = ("x.a".attr > 3) && ("x.a".attr < 13) && ("y.c".attr <= 5)
  val yOnlyBranch = ("y.a".attr > 2) && ("y.c".attr < 1)
  val joinCondition = joinKey && (mixedBranch || yOnlyBranch)

  val originalQuery = x.join(y, condition = Some(joinCondition))
  val optimized = Optimize.execute(originalQuery.analyze)

  // Expected plan: no filter on the x side; the y side receives a single compact predicate
  // `c <= 5 || (a > 2 && c < 1)`; the join condition itself is kept unchanged.
  val expectedLeft = testRelation.subquery('x)
  val expectedRight = testRelation.where('c <= 5 || ('a > 2 && 'c < 1)).subquery('y)
  val correctAnswer = expectedLeft.join(expectedRight, condition = Some(joinCondition)).analyze

  comparePlans(optimized, correctAnswer)
}

test(s"Disable rewrite to CNF by setting ${SQLConf.MAX_CNF_NODE_COUNT.key}=0") {
val x = testRelation.subquery('x)
val y = testRelation.subquery('y)
Expand All @@ -1384,7 +1396,7 @@ class FilterPushdownSuite extends PlanTest {
(testRelation.subquery('x), testRelation.subquery('y))
} else {
(testRelation.subquery('x),
testRelation.where(('c <= 5 || 'c < 1) && ('c <=5 || 'a > 2)).subquery('y))
testRelation.where('c <= 5 || ('a > 2 && 'c < 1)).subquery('y))
}
val correctAnswer = left.join(right, condition = Some("x.b".attr === "y.b".attr
&& ((("x.a".attr > 3) && ("x.a".attr < 13) && ("y.c".attr <= 5))
Expand Down

0 comments on commit cbc1220

Please sign in to comment.