[SPARK-26021][2.4][SQL][FOLLOWUP] only deal with NaN and -0.0 in UnsafeWriter

backport apache#23239 to 2.4

---------

## What changes were proposed in this pull request?

A followup of apache#23043

There are 4 places where we need to deal with NaN and -0.0:
1. Comparison expressions. `-0.0` and `0.0` should be treated as the same. Different NaNs should be treated as the same.
2. Join keys. `-0.0` and `0.0` should be treated as the same. Different NaNs should be treated as the same.
3. Grouping keys. `-0.0` and `0.0` should be assigned to the same group. Different NaNs should be assigned to the same group.
4. Window partition keys. `-0.0` and `0.0` should be treated as the same. Different NaNs should be treated as the same.

Case 1 is OK. Our comparison already handles NaN and -0.0, and for struct/array/map, we recursively compare the fields/elements.
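To illustrate the comparison semantics that case 1 relies on, here is a minimal Java sketch. The class `NanAwareCompare` and its `compare` method are hypothetical names for illustration only and are not Spark's comparison code. It assumes the documented Spark SQL behavior that all NaNs are equal to each other and sort after every other value.

```java
public class NanAwareCompare {
    // Hypothetical helper, not Spark's code: every NaN is equal to every other
    // NaN and is ordered after all non-NaN values, while -0.0 and 0.0 compare
    // as equal.
    static int compare(double a, double b) {
        boolean aNaN = Double.isNaN(a);
        boolean bNaN = Double.isNaN(b);
        if (aNaN || bNaN) {
            if (aNaN && bNaN) {
                return 0;
            }
            return aNaN ? 1 : -1;
        }
        // The primitive < and > operators already treat -0.0 and 0.0 as equal.
        if (a < b) {
            return -1;
        }
        if (a > b) {
            return 1;
        }
        return 0;
    }

    public static void main(String[] args) {
        // An arbitrary quiet NaN whose bit pattern differs from Double.NaN.
        double otherNaN = Double.longBitsToDouble(0x7ff8000000000001L);
        System.out.println(compare(Double.NaN, otherNaN));         // 0
        System.out.println(compare(-0.0d, 0.0d));                  // 0
        System.out.println(compare(Double.NaN, Double.MAX_VALUE)); // 1
    }
}
```

By contrast, the JDK's `Double.compare(-0.0d, 0.0d)` returns a negative value, so a value-based comparison has to choose these semantics explicitly.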

Cases 2, 3 and 4 are problematic, as they compare the `UnsafeRow` binary directly: different NaNs have different binary representations, and so do -0.0 and 0.0.
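The mismatch between value equality and binary equality can be seen in plain Java, independent of Spark. The class and helper names below are made up for this sketch, and the NaN bit pattern `0x7ff8000000000001L` is just one arbitrary non-canonical quiet NaN chosen for illustration.

```java
public class FloatBitPatterns {
    // Returns true when the two doubles have identical 64-bit representations.
    // doubleToRawLongBits is used because doubleToLongBits would collapse all
    // NaNs into the canonical one and hide the difference.
    static boolean sameRawBits(double a, double b) {
        return Double.doubleToRawLongBits(a) == Double.doubleToRawLongBits(b);
    }

    // An arbitrary quiet NaN whose payload differs from Double.NaN.
    static double nonCanonicalNaN() {
        return Double.longBitsToDouble(0x7ff8000000000001L);
    }

    public static void main(String[] args) {
        // Equal as values, different as bytes.
        System.out.println(0.0d == -0.0d);            // true
        System.out.println(sameRawBits(0.0d, -0.0d)); // false

        // Both are NaN, yet their bytes differ.
        System.out.println(Double.isNaN(nonCanonicalNaN()));           // true
        System.out.println(sameRawBits(Double.NaN, nonCanonicalNaN())); // false
    }
}
```

Any code path that hashes or compares the raw bytes of a row would therefore put these values into different join buckets, groups, or window partitions.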

To fix it, a simple solution is to normalize float/double when building unsafe data (`UnsafeRow`, `UnsafeArrayData`, `UnsafeMapData`). Then we don't need to worry about it anymore.
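The idea of normalizing at write time can be sketched outside of Spark as follows. This is a standalone illustration under stated assumptions, not the `UnsafeWriter` code: `NormalizeOnWrite`, `normalize`, `encode` and `encodeRaw` are hypothetical names, and a `ByteBuffer` stands in for the unsafe row buffer. Only the double case is shown; the float case is analogous.

```java
import java.nio.ByteBuffer;
import java.util.Arrays;

public class NormalizeOnWrite {
    // Maps every NaN to the canonical Double.NaN and -0.0 to 0.0.
    // Note that `value == -0.0d` is also true for 0.0, which is harmless.
    static double normalize(double value) {
        if (Double.isNaN(value)) {
            return Double.NaN;
        } else if (value == -0.0d) {
            return 0.0d;
        }
        return value;
    }

    // Writes the exact bits of the value without any normalization.
    static byte[] encodeRaw(double value) {
        return ByteBuffer.allocate(Double.BYTES)
            .putLong(Double.doubleToRawLongBits(value))
            .array();
    }

    // Normalizes first, so values that are equal under SQL semantics produce
    // identical bytes.
    static byte[] encode(double value) {
        return encodeRaw(normalize(value));
    }

    public static void main(String[] args) {
        double otherNaN = Double.longBitsToDouble(0x7ff8000000000001L);
        System.out.println(Arrays.equals(encodeRaw(0.0d), encodeRaw(-0.0d)));     // false
        System.out.println(Arrays.equals(encode(0.0d), encode(-0.0d)));           // true
        System.out.println(Arrays.equals(encode(Double.NaN), encode(otherNaN)));  // true
    }
}
```

Once every writer goes through the normalizing path, downstream code can keep comparing bytes directly without special-casing floating-point columns.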

Following this direction, this PR moves the handling of NaN and -0.0 from `Platform` to `UnsafeWriter`, so that places like `UnsafeRow.setFloat` will not handle them, which reduces the perf overhead. It's also easier to add comments explaining why we do it in `UnsafeWriter`.

## How was this patch tested?

existing tests

Closes apache#23265 from cloud-fan/minor.

Authored-by: Wenchen Fan <wenchen@databricks.com>
Signed-off-by: Dongjoon Hyun <dongjoon@apache.org>
cloud-fan authored and kai-chi committed Aug 1, 2019
1 parent 8307d42 commit a79b821
Showing 6 changed files with 81 additions and 24 deletions.
10 changes: 0 additions & 10 deletions common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java
@@ -120,11 +120,6 @@ public static float getFloat(Object object, long offset) {
   }

   public static void putFloat(Object object, long offset, float value) {
-    if (Float.isNaN(value)) {
-      value = Float.NaN;
-    } else if (value == -0.0f) {
-      value = 0.0f;
-    }
     _UNSAFE.putFloat(object, offset, value);
   }

@@ -133,11 +128,6 @@ public static double getDouble(Object object, long offset) {
   }

   public static void putDouble(Object object, long offset, double value) {
-    if (Double.isNaN(value)) {
-      value = Double.NaN;
-    } else if (value == -0.0d) {
-      value = 0.0d;
-    }
     _UNSAFE.putDouble(object, offset, value);
   }

@@ -157,18 +157,4 @@ public void heapMemoryReuse() {
     Assert.assertEquals(onheap4.size(), 1024 * 1024 + 7);
     Assert.assertEquals(obj3, onheap4.getBaseObject());
   }
-
-  @Test
-  // SPARK-26021
-  public void writeMinusZeroIsReplacedWithZero() {
-    byte[] doubleBytes = new byte[Double.BYTES];
-    byte[] floatBytes = new byte[Float.BYTES];
-    Platform.putDouble(doubleBytes, Platform.BYTE_ARRAY_OFFSET, -0.0d);
-    Platform.putFloat(floatBytes, Platform.BYTE_ARRAY_OFFSET, -0.0f);
-    double doubleFromPlatform = Platform.getDouble(doubleBytes, Platform.BYTE_ARRAY_OFFSET);
-    float floatFromPlatform = Platform.getFloat(floatBytes, Platform.BYTE_ARRAY_OFFSET);
-
-    Assert.assertEquals(Double.doubleToLongBits(0.0d), Double.doubleToLongBits(doubleFromPlatform));
-    Assert.assertEquals(Float.floatToIntBits(0.0f), Float.floatToIntBits(floatFromPlatform));
-  }
 }
@@ -198,11 +198,46 @@ protected final void writeLong(long offset, long value) {
     Platform.putLong(getBuffer(), offset, value);
   }

+  // We need to take care of NaN and -0.0 in several places:
+  //   1. When comparing values, different NaNs should be treated as the same, `-0.0` and `0.0`
+  //      should be treated as the same.
+  //   2. In GROUP BY, different NaNs should belong to the same group, -0.0 and 0.0 should belong
+  //      to the same group.
+  //   3. As join keys, different NaNs should be treated as the same, `-0.0` and `0.0` should be
+  //      treated as the same.
+  //   4. As window partition keys, different NaNs should be treated as the same, `-0.0` and `0.0`
+  //      should be treated as the same.
+  //
+  // Case 1 is fine, as we handle NaN and -0.0 well during comparison. For complex types, we
+  // recursively compare the fields/elements, so it's also fine.
+  //
+  // Cases 2, 3 and 4 are problematic, as they compare the `UnsafeRow` binary directly: different
+  // NaNs have different binary representations, and so do -0.0 and 0.0.
+  //
+  // Here we normalize NaN and -0.0, so that `UnsafeProjection` will normalize them when writing
+  // float/double columns and nested fields to `UnsafeRow`.
+  //
+  // Note that we must do this for all the `UnsafeProjection`s, not only the ones that extract
+  // join/grouping/window partition keys. `UnsafeProjection` copies unsafe data directly for complex
+  // types, so nested float/double may not be normalized. We need to make sure that all the unsafe
+  // data (`UnsafeRow`, `UnsafeArrayData`, `UnsafeMapData`) will have float/double normalized
+  // during creation.
   protected final void writeFloat(long offset, float value) {
+    if (Float.isNaN(value)) {
+      value = Float.NaN;
+    } else if (value == -0.0f) {
+      value = 0.0f;
+    }
     Platform.putFloat(getBuffer(), offset, value);
   }

+  // See comments for `writeFloat`.
   protected final void writeDouble(long offset, double value) {
+    if (Double.isNaN(value)) {
+      value = Double.NaN;
+    } else if (value == -0.0d) {
+      value = 0.0d;
+    }
     Platform.putDouble(getBuffer(), offset, value);
   }
 }
@@ -50,4 +50,24 @@ class UnsafeRowWriterSuite extends SparkFunSuite {
     assert(res1 == res2)
   }

+  test("SPARK-26021: normalize float/double NaN and -0.0") {
+    val unsafeRowWriter1 = new UnsafeRowWriter(4)
+    unsafeRowWriter1.resetRowWriter()
+    unsafeRowWriter1.write(0, Float.NaN)
+    unsafeRowWriter1.write(1, Double.NaN)
+    unsafeRowWriter1.write(2, 0.0f)
+    unsafeRowWriter1.write(3, 0.0)
+    val res1 = unsafeRowWriter1.getRow
+
+    val unsafeRowWriter2 = new UnsafeRowWriter(4)
+    unsafeRowWriter2.resetRowWriter()
+    unsafeRowWriter2.write(0, 0.0f/0.0f)
+    unsafeRowWriter2.write(1, 0.0/0.0)
+    unsafeRowWriter2.write(2, -0.0f)
+    unsafeRowWriter2.write(3, -0.0)
+    val res2 = unsafeRowWriter2.getRow
+
+    // The two rows should be equal
+    assert(res1 == res2)
+  }
 }
@@ -295,4 +295,16 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext {
       df.join(df, df("id") <=> df("id")).queryExecution.optimizedPlan
     }
   }
+
+  test("NaN and -0.0 in join keys") {
+    val df1 = Seq(Float.NaN -> Double.NaN, 0.0f -> 0.0, -0.0f -> -0.0).toDF("f", "d")
+    val df2 = Seq(Float.NaN -> Double.NaN, 0.0f -> 0.0, -0.0f -> -0.0).toDF("f", "d")
+    val joined = df1.join(df2, Seq("f", "d"))
+    checkAnswer(joined, Seq(
+      Row(Float.NaN, Double.NaN),
+      Row(0.0f, 0.0),
+      Row(0.0f, 0.0),
+      Row(0.0f, 0.0),
+      Row(0.0f, 0.0)))
+  }
 }
@@ -658,4 +658,18 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext {
         |GROUP BY a
         |HAVING SUM(b) = 5 AND RANK() OVER(ORDER BY a) = 1""".stripMargin))
   }
+
+  test("NaN and -0.0 in window partition keys") {
+    val df = Seq(
+      (Float.NaN, Double.NaN, 1),
+      (0.0f/0.0f, 0.0/0.0, 1),
+      (0.0f, 0.0, 1),
+      (-0.0f, -0.0, 1)).toDF("f", "d", "i")
+    val result = df.select($"f", count("i").over(Window.partitionBy("f", "d")))
+    checkAnswer(result, Seq(
+      Row(Float.NaN, 2),
+      Row(Float.NaN, 2),
+      Row(0.0f, 2),
+      Row(0.0f, 2)))
+  }
 }
