address comments from Bart
cloud-fan committed Jan 8, 2019
1 parent 8dafc64 commit 3e8c171
Showing 4 changed files with 119 additions and 49 deletions.
NormalizeFloatingNumbers.scala
@@ -28,26 +28,30 @@ import org.apache.spark.sql.types._
* We need to take care of special floating-point numbers (NaN and -0.0) in several places:
* 1. When comparing values, different NaNs should be treated as the same, and `-0.0` and `0.0`
* should be treated as the same.
* 2. In GROUP BY, different NaNs should belong to the same group, -0.0 and 0.0 should belong
* to the same group.
* 2. In aggregate grouping keys, different NaNs should belong to the same group, -0.0 and 0.0
* should belong to the same group.
* 3. In join keys, different NaNs should be treated as the same, and `-0.0` and `0.0` should be
* treated as the same.
* 4. In window partition keys, different NaNs should be treated as same, `-0.0` and `0.0`
* should be treated as same.
* 4. In window partition keys, different NaNs should belong to the same partition, -0.0 and 0.0
* should belong to the same partition.
*
* Case 1 is fine, as we handle NaN and -0.0 well during comparison. For complex types, we
* recursively compare the fields/elements, so it's also fine.
*
* Case 2, 3 and 4 are problematic, as they compare `UnsafeRow` binary directly, and different
* NaNs have different binary representation, and the same thing happens for -0.0 and 0.0.
* Case 2, 3 and 4 are problematic, as Spark SQL turns grouping/join/window partition keys into
* binary `UnsafeRow` and compares the binary data directly. Different NaNs have different binary
* representations, and the same thing happens for -0.0 and 0.0.
*
* This rule normalizes NaN and -0.0 in Window partition keys, Join keys and Aggregate grouping
* expressions.
* This rule normalizes NaN and -0.0 in window partition keys, join keys and aggregate grouping
* keys.
*
* Note that, this rule should be an analyzer rule, as it must be applied to make the query result
* corrected. Currently it's executed as an optimizer rule, because the optimizer may create new
* joins(for subquery) and reorder joins(may change the join condition), and this rule needs to be
* executed at the end.
* Ideally we should do the normalization in the physical operators that compare the
* binary `UnsafeRow` directly. We would not need this normalization if the Spark SQL execution
* engine were not optimized to run on binary data. This rule exists to simplify the
* implementation, so that we have a single place to do the normalization, which is more
* maintainable.
*
* Note that this rule must be executed at the end of the optimizer, because the optimizer may
* create new joins (the subquery rewrite) and new join conditions (the join reorder).
*/
object NormalizeFloatingNumbers extends Rule[LogicalPlan] {

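As background for the comment above: the special values compare equal (or, for NaN, are meant to land in the same group) while their raw bit patterns differ, which is exactly what breaks byte-wise `UnsafeRow` comparison. A minimal, Spark-independent sketch of that fact (plain JVM semantics; the object name is illustrative, not part of this commit):

```scala
import java.lang.Double.{doubleToRawLongBits, longBitsToDouble}

object SpecialFloatBits {
  def main(args: Array[String]): Unit = {
    // -0.0 and 0.0 are equal as values but have different raw bit patterns, so a
    // byte-wise comparison of encoded rows would treat them as distinct keys.
    assert(-0.0d == 0.0d)
    assert(doubleToRawLongBits(-0.0d) != doubleToRawLongBits(0.0d))

    // NaNs produced in different ways may carry different bit patterns as well,
    // even though Spark wants them all to belong to the same group/partition.
    val canonicalNaN = Double.NaN
    val anotherNaN = longBitsToDouble(doubleToRawLongBits(Double.NaN) | 1L)
    assert(anotherNaN.isNaN)
    assert(doubleToRawLongBits(canonicalNaN) != doubleToRawLongBits(anotherNaN))
  }
}
```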
@@ -58,10 +62,18 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {

case _ => plan transform {
case w: Window if w.partitionSpec.exists(p => needNormalize(p.dataType)) =>
// Although the `windowExpressions` may refer to `partitionSpec` expressions, we don't need
// to normalize the `windowExpressions`, as they are executed per input row and should take
// the input row as it is.
w.copy(partitionSpec = w.partitionSpec.map(normalize))

case j @ ExtractEquiJoinKeys(_, leftKeys, rightKeys, condition, _, _)
if leftKeys.exists(k => needNormalize(k.dataType)) =>
// Only hash join and sort merge join need the normalization. Here we catch all Joins with
// join keys, assuming Joins with join keys are always planned as hash join or sort merge
// join. It's very unlikely that we will break this assumption in the near future.
case j @ ExtractEquiJoinKeys(_, leftKeys, rightKeys, condition, _, _, _)
// The analyzer guarantees the left and right join keys are of the same data type. Here we
// only need to check the join keys of one side.
if leftKeys.exists(k => needNormalize(k.dataType)) =>
val newLeftJoinKeys = leftKeys.map(normalize)
val newRightJoinKeys = rightKeys.map(normalize)
val newConditions = newLeftJoinKeys.zip(newRightJoinKeys).map {
@@ -79,7 +91,9 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
case FloatType | DoubleType => true
case StructType(fields) => fields.exists(f => needNormalize(f.dataType))
case ArrayType(et, _) => needNormalize(et)
// We don't need to handle MapType here, as it's not comparable.
// Currently MapType is not comparable, and the analyzer should fail earlier if this case happens.
case _: MapType =>
throw new IllegalStateException("grouping/join/window partition keys cannot be map type.")
case _ => false
}

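The join case above is cut off by the hunk boundary. A hedged sketch of the overall pattern (normalize both sides' keys, rebuild the equality predicates, and AND them with any leftover condition) follows; `RebuildJoinConditionSketch`, its parameters and the `normalize` argument are illustrative and not the actual rule code:

```scala
import org.apache.spark.sql.catalyst.expressions.{And, EqualTo, Expression}

object RebuildJoinConditionSketch {
  // Given already-extracted equi-join keys, apply a normalization function to both
  // sides, turn each key pair back into an equality predicate, and conjoin the
  // predicates with any remaining non-equi condition.
  def rebuildJoinCondition(
      leftKeys: Seq[Expression],
      rightKeys: Seq[Expression],
      otherCondition: Option[Expression],
      normalize: Expression => Expression): Expression = {
    val normalizedEquals = leftKeys.map(normalize).zip(rightKeys.map(normalize)).map {
      case (left, right) => EqualTo(left, right)
    }
    // Assumes at least one join key, so the reduce is over a non-empty sequence.
    (normalizedEquals ++ otherCondition).reduce(And)
  }
}
```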
DataFrameAggregateSuite.scala
@@ -725,6 +725,13 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
}

test("SPARK-26021: NaN and -0.0 in grouping expressions") {
import java.lang.Float.floatToRawIntBits
import java.lang.Double.doubleToRawLongBits

// 0.0/0.0 and NaN have different binary representations.
assert(floatToRawIntBits(0.0f/0.0f) != floatToRawIntBits(Float.NaN))
assert(doubleToRawLongBits(0.0/0.0) != doubleToRawLongBits(Double.NaN))

checkAnswer(
Seq(0.0f, -0.0f, 0.0f/0.0f, Float.NaN).toDF("f").groupBy("f").count(),
Row(0.0f, 2) :: Row(Float.NaN, 2) :: Nil)
@@ -734,19 +741,35 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {

// test with complicated type grouping expressions
checkAnswer(
Seq(0.0f, -0.0f, 0.0f/0.0f, Float.NaN).toDF("f").groupBy(array("f"), struct("f")).count(),
Row(Seq(0.0f), Row(0.0f), 2) :: Row(Seq(Float.NaN), Row(Float.NaN), 2) :: Nil)
Seq(0.0f, -0.0f, 0.0f/0.0f, Float.NaN).toDF("f")
.groupBy(array("f"), struct("f")).count(),
Row(Seq(0.0f), Row(0.0f), 2) ::
Row(Seq(Float.NaN), Row(Float.NaN), 2) :: Nil)
checkAnswer(
Seq(0.0d, -0.0d, 0.0d/0.0d, Double.NaN).toDF("d")
.groupBy(array("d"), struct("d")).count(),
Row(Seq(0.0d), Row(0.0d), 2) ::
Row(Seq(Double.NaN), Row(Double.NaN), 2) :: Nil)

checkAnswer(
Seq(0.0f, -0.0f, 0.0f/0.0f, Float.NaN).toDF("f")
.groupBy(array(struct("f")), struct(array("f"))).count(),
Row(Seq(Row(0.0f)), Row(Seq(0.0f)), 2) ::
Row(Seq(Row(Float.NaN)), Row(Seq(Float.NaN)), 2) :: Nil)
checkAnswer(
Seq(0.0d, -0.0d, 0.0d/0.0d, Double.NaN).toDF("d").groupBy(array("d"), struct("d")).count(),
Row(Seq(0.0d), Row(0.0d), 2) :: Row(Seq(Double.NaN), Row(Double.NaN), 2) :: Nil)
Seq(0.0d, -0.0d, 0.0d/0.0d, Double.NaN).toDF("d")
.groupBy(array(struct("d")), struct(array("d"))).count(),
Row(Seq(Row(0.0d)), Row(Seq(0.0d)), 2) ::
Row(Seq(Row(Double.NaN)), Row(Seq(Double.NaN)), 2) :: Nil)

// test with complicated type grouping columns
val df = Seq(
Array(-0.0f, 0.0f) -> Tuple2(-0.0d, Double.NaN),
Array(0.0f, -0.0f) -> Tuple2(0.0d, Double.NaN)).toDF("arr", "stru")
(Array(-0.0f, 0.0f), Tuple2(-0.0d, Double.NaN), Seq(Tuple2(-0.0d, Double.NaN))),
(Array(0.0f, -0.0f), Tuple2(0.0d, Double.NaN), Seq(Tuple2(0.0d, 0.0/0.0)))
).toDF("arr", "stru", "arrOfStru")
checkAnswer(
df.groupBy("arr", "stru").count(),
Row(Seq(0.0f, 0.0f), Row(0.0d, Double.NaN), 2)
df.groupBy("arr", "stru", "arrOfStru").count(),
Row(Seq(0.0f, 0.0f), Row(0.0d, Double.NaN), Seq(Row(0.0d, Double.NaN)), 2)
)
}
}
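The grouping behavior these assertions pin down can be reproduced interactively; the sketch below assumes a `spark-shell` session (so `spark` and its implicits are in scope) and mirrors the simple float case from the test above rather than adding anything new:

```scala
import spark.implicits._

// 0.0f, -0.0f, 0.0f/0.0f and Float.NaN collapse into exactly two groups:
// one for zero (positive and negative) and one for NaN (any bit pattern).
val counts = Seq(0.0f, -0.0f, 0.0f / 0.0f, Float.NaN)
  .toDF("f")
  .groupBy("f")
  .count()
  .collect()

assert(counts.length == 2)
assert(counts.forall(_.getLong(1) == 2L))
```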
DataFrameWindowFunctionsSuite.scala
@@ -697,6 +697,13 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext {
}

test("NaN and -0.0 in window partition keys") {
import java.lang.Float.floatToRawIntBits
import java.lang.Double.doubleToRawLongBits

// 0.0/0.0 and NaN have different binary representations.
assert(floatToRawIntBits(0.0f/0.0f) != floatToRawIntBits(Float.NaN))
assert(doubleToRawLongBits(0.0/0.0) != doubleToRawLongBits(Double.NaN))

val df = Seq(
(Float.NaN, Double.NaN),
(0.0f/0.0f, 0.0/0.0),
@@ -712,8 +719,18 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext {
Row(-0.0f, 2)))

// test with complicated window partition keys.
val windowSpec1 = Window.partitionBy(array("f"), struct("d"))
checkAnswer(
df.select($"f", count(lit(1)).over(windowSpec1)),
Seq(
Row(Float.NaN, 2),
Row(0.0f/0.0f, 2),
Row(0.0f, 2),
Row(-0.0f, 2)))

val windowSpec2 = Window.partitionBy(array(struct("f")), struct(array("d")))
checkAnswer(
df.select($"f", count(lit(1)).over(Window.partitionBy(array("f"), struct("d")))),
df.select($"f", count(lit(1)).over(windowSpec2)),
Seq(
Row(Float.NaN, 2),
Row(0.0f/0.0f, 2),
@@ -722,12 +739,14 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext {

// test with df with complicated-type columns.
val df2 = Seq(
Array(-0.0f, 0.0f) -> Tuple2(-0.0d, Double.NaN),
Array(0.0f, -0.0f) -> Tuple2(0.0d, Double.NaN)).toDF("arr", "stru")
(Array(-0.0f, 0.0f), Tuple2(-0.0d, Double.NaN), Seq(Tuple2(-0.0d, Double.NaN))),
(Array(0.0f, -0.0f), Tuple2(0.0d, Double.NaN), Seq(Tuple2(0.0d, 0.0/0.0)))
).toDF("arr", "stru", "arrOfStru")
val windowSpec3 = Window.partitionBy("arr", "stru", "arrOfStru")
checkAnswer(
df2.select($"arr", $"stru", count(lit(1)).over(Window.partitionBy("arr", "stru"))),
df2.select($"arr", $"stru", $"arrOfStru", count(lit(1)).over(windowSpec3)),
Seq(
Row(Seq(-0.0f, 0.0f), Row(-0.0d, Double.NaN), 2),
Row(Seq(0.0f, -0.0f), Row(0.0d, Double.NaN), 2)))
Row(Seq(-0.0f, 0.0f), Row(-0.0d, Double.NaN), Seq(Row(-0.0d, Double.NaN)), 2),
Row(Seq(0.0f, -0.0f), Row(0.0d, Double.NaN), Seq(Row(0.0d, 0.0/0.0)), 2)))
}
}
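The same collapsing applies to window partition keys. A hedged user-level sketch of the simple case follows; it assumes a `spark-shell` session and that the test's input pairs each float with the matching double, as the expected rows above imply:

```scala
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{count, lit}
import spark.implicits._

val df = Seq(
  (Float.NaN, Double.NaN),
  (0.0f / 0.0f, 0.0 / 0.0),
  (0.0f, 0.0),
  (-0.0f, -0.0)).toDF("f", "d")

// Partitioning by ("f", "d") puts the two NaN rows in one partition and the two
// zero rows in another, so every row sees a partition count of 2.
val counted = df
  .select($"f", count(lit(1)).over(Window.partitionBy("f", "d")))
  .collect()
assert(counted.forall(_.getLong(1) == 2L))
```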
54 changes: 34 additions & 20 deletions sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -918,41 +918,55 @@ class JoinSuite extends QueryTest with SharedSQLContext {
checkAnswer(
sql(
"""
|SELECT v1.f, v1.d
|SELECT v1.f, v1.d, v2.f, v2.d
|FROM v1 JOIN v2
|ON v1.f = v2.f AND v1.d = v2.d
""".stripMargin),
Seq(
Row(Float.NaN, Double.NaN),
Row(0.0f, 0.0),
Row(0.0f, 0.0),
Row(-0.0f, -0.0),
Row(-0.0f, -0.0)))
Row(Float.NaN, Double.NaN, Float.NaN, Double.NaN),
Row(0.0f, 0.0, 0.0f, 0.0),
Row(0.0f, 0.0, -0.0f, -0.0),
Row(-0.0f, -0.0, 0.0f, 0.0),
Row(-0.0f, -0.0, -0.0f, -0.0)))

// test with complicated join keys.
checkAnswer(
sql(
"""
|SELECT v1.f, v1.d
|SELECT v1.f, v1.d, v2.f, v2.d
|FROM v1 JOIN v2
|ON array(v1.f) = array(v2.f) AND struct(v1.d) = struct(v2.d)
|ON
| array(v1.f) = array(v2.f) AND
| struct(v1.d) = struct(v2.d) AND
| array(struct(v1.f, v1.d)) = array(struct(v2.f, v2.d)) AND
| struct(array(v1.f), array(v1.d)) = struct(array(v2.f), array(v2.d))
""".stripMargin),
Seq(
Row(Float.NaN, Double.NaN),
Row(0.0f, 0.0),
Row(0.0f, 0.0),
Row(-0.0f, -0.0),
Row(-0.0f, -0.0)))
Row(Float.NaN, Double.NaN, Float.NaN, Double.NaN),
Row(0.0f, 0.0, 0.0f, 0.0),
Row(0.0f, 0.0, -0.0f, -0.0),
Row(-0.0f, -0.0, 0.0f, 0.0),
Row(-0.0f, -0.0, -0.0f, -0.0)))

// test with tables with complicated-type columns.
Seq(Array(-0.0f, 0.0f) -> Tuple2(-0.0d, Double.NaN)).toDF("arr", "stru").createTempView("v3")
Seq(Array(0.0f, -0.0f) -> Tuple2(0.0d, Double.NaN)).toDF("arr", "stru").createTempView("v4")
Seq((Array(-0.0f, 0.0f), Tuple2(-0.0d, Double.NaN), Seq(Tuple2(-0.0d, Double.NaN))))
.toDF("arr", "stru", "arrOfStru").createTempView("v3")
Seq((Array(0.0f, -0.0f), Tuple2(0.0d, 0.0/0.0), Seq(Tuple2(0.0d, 0.0/0.0))))
.toDF("arr", "stru", "arrOfStru").createTempView("v4")
checkAnswer(
sql("SELECT v3.arr, v3.stru FROM v3 JOIN v4 ON v3.arr = v4.arr AND v3.stru = v4.stru"),
Seq(Row(Seq(-0.0f, 0.0f), Row(-0.0d, Double.NaN))))
checkAnswer(
sql("SELECT v4.arr, v4.stru FROM v3 JOIN v4 ON v3.arr = v4.arr AND v3.stru = v4.stru"),
Seq(Row(Seq(0.0f, -0.0f), Row(0.0d, Double.NaN))))
sql(
"""
|SELECT v3.arr, v3.stru, v3.arrOfStru, v4.arr, v4.stru, v4.arrOfStru
|FROM v3 JOIN v4
|ON v3.arr = v4.arr AND v3.stru = v4.stru AND v3.arrOfStru = v4.arrOfStru
""".stripMargin),
Seq(Row(
Seq(-0.0f, 0.0f),
Row(-0.0d, Double.NaN),
Seq(Row(-0.0d, Double.NaN)),
Seq(0.0f, -0.0f),
Row(0.0d, 0.0/0.0),
Seq(Row(0.0d, 0.0/0.0)))))
}
}
}

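For the join semantics being tested, here is a hedged standalone sketch, independent of the `v1`/`v2` views defined earlier in `JoinSuite` (which are outside this diff); the frames and names below are illustrative and assume a `spark-shell` session:

```scala
import spark.implicits._

val lhs = Seq(Double.NaN, 0.0d, -0.0d).toDF("d")
val rhs = Seq(Double.NaN, -0.0d, 0.0d).toDF("d")

// With join-key normalization, NaN matches NaN and each zero matches both zeros,
// so the equi-join produces 1 + 2 * 2 = 5 rows, the same matching pattern the
// expected rows above encode for v1 and v2.
val joined = lhs.join(rhs, lhs("d") === rhs("d"))
assert(joined.count() == 5L)
```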