From ee0ef91a7047d47328efac753e66ec97a91c0e37 Mon Sep 17 00:00:00 2001 From: Alon Doron Date: Wed, 14 Nov 2018 18:18:30 +0200 Subject: [PATCH 1/7] replace -0.0 with 0.0 in BoundAttribute added tests --- .../catalyst/expressions/BoundAttribute.scala | 23 ++++++++++++--- .../GenerateUnsafeProjectionSuite.scala | 16 +++++++++++ .../spark/sql/DataFrameAggregateSuite.scala | 28 +++++++++++++++++++ 3 files changed, 63 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 77582e10f9ff2..3f49424d2b0c9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors.attachTree -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral, JavaCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{Block, CodegenContext, CodeGenerator, ExprCode, FalseLiteral, JavaCode} import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types._ @@ -56,17 +56,32 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) val javaType = JavaCode.javaType(dataType) val value = CodeGenerator.getValue(ctx.INPUT_ROW, dataType, ordinal.toString) if (nullable) { - ev.copy(code = + var codeBlock = code""" |boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal); |$javaType ${ev.value} = ${ev.isNull} ? | ${CodeGenerator.defaultValue(dataType)} : ($value); - """.stripMargin) + """.stripMargin + codeBlock = codeBlock + genReplaceMinusZeroWithZeroCode(javaType.codeString, ev.value) + ev.copy(code = codeBlock) } else { - ev.copy(code = code"$javaType ${ev.value} = $value;", isNull = FalseLiteral) + var codeBlock = code"$javaType ${ev.value} = $value;" + codeBlock = codeBlock + genReplaceMinusZeroWithZeroCode(javaType.codeString, ev.value) + ev.copy(code = codeBlock, isNull = FalseLiteral) } } } + + private def genReplaceMinusZeroWithZeroCode(javaType: String, value: String): Block = { + val code: String = s"\nif ($value == -0.0%c) $value = 0.0%c;" + var formattedCode: String = "" + javaType match { + case "double" | "java.lang.Double" => formattedCode = code.format('d', 'd') + case "float" | "java.lang.Float" => formattedCode = code.format('f', 'f') + case _ => + } + code"$formattedCode" + } } object BindReferences extends Logging { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjectionSuite.scala index 01aa3579aea98..17bca948369b6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjectionSuite.scala @@ -68,6 +68,22 @@ class GenerateUnsafeProjectionSuite extends SparkFunSuite { assert(!result3.getStruct(0, 2).isNullAt(0)) assert(!result3.getStruct(0, 3).isNullAt(0)) } + + test("SPARK-26021: Test replacing -0.0 with 0.0") { + val exprs = + BoundReference(0, DoubleType, nullable = false) :: + BoundReference(1, DoubleType, nullable = true) :: + BoundReference(2, FloatType, nullable = false) :: + BoundReference(3, FloatType, nullable = true) :: + Nil + val projection = GenerateUnsafeProjection.generate(exprs) + val result = projection.apply(InternalRow(-0.0d, Double.box(-0.0d), -0.0f, Float.box(-0.0f))) + // using compare since 0.0 == -0.0 is true + assert(result.getDouble(0).compareTo(0.0d) == 0) + assert(result.getDouble(1).compareTo(0.0d) == 0) + assert(result.getFloat(2).compareTo(0.0f) == 0) + assert(result.getFloat(3).compareTo(0.0f) == 0) + } } object AlwaysNull extends InternalRow { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index d9ba6e2ce5120..d34933e13ae05 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -723,4 +723,32 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { "grouping expressions: [current_date(None)], value: [key: int, value: string], " + "type: GroupBy]")) } + + test("SPARK-26021: Double and Float 0.0/-0.0 should be equal when grouping") { + val colName = "i" + def groupByCollect(df: DataFrame): Array[Row] = { + df.groupBy(colName).count().collect() + } + def assertResult[T](result: Array[Row], zero: T)(implicit ordering: Ordering[T]): Unit = { + assert(result.length == 1) + // using compare since 0.0 == -0.0 is true + assert(ordering.compare(result(0).getAs[T](0), zero) == 0) + assert(result(0).getLong(1) == 3) + } + + spark.conf.set("spark.sql.codegen.wholeStage", "false") + val doubles = + groupByCollect(Seq(0.0d, 0.0d, -0.0d).toDF(colName)) + val doublesBoxed = + groupByCollect(Seq(Double.box(0.0d), Double.box(0.0d), Double.box(-0.0d)).toDF(colName)) + val floats = + groupByCollect(Seq(0.0f, -0.0f, 0.0f).toDF(colName)) + val floatsBoxed = + groupByCollect(Seq(Float.box(0.0f), Float.box(-0.0f), Float.box(0.0f)).toDF(colName)) + + assertResult(doubles, 0.0d) + assertResult(doublesBoxed, 0.0d) + assertResult(floats, 0.0f) + assertResult(floatsBoxed, 0.0f) + } } From 63b7f59ad44d0876ea6dde02e4204fc0140d0df6 Mon Sep 17 00:00:00 2001 From: Alon Doron Date: Wed, 14 Nov 2018 18:27:24 +0200 Subject: [PATCH 2/7] minor remove var type --- .../spark/sql/catalyst/expressions/BoundAttribute.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 3f49424d2b0c9..f87fe367a8e83 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -73,8 +73,8 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) } private def genReplaceMinusZeroWithZeroCode(javaType: String, value: String): Block = { - val code: String = s"\nif ($value == -0.0%c) $value = 0.0%c;" - var formattedCode: String = "" + val code = s"\nif ($value == -0.0%c) $value = 0.0%c;" + var formattedCode = "" javaType match { case "double" | "java.lang.Double" => formattedCode = code.format('d', 'd') case "float" | "java.lang.Float" => formattedCode = code.format('f', 'f') From f48d4ef4ba90c08f82c92d32325ecf4dc1d05ab4 Mon Sep 17 00:00:00 2001 From: Alon Doron Date: Sat, 17 Nov 2018 20:05:00 +0200 Subject: [PATCH 3/7] revert + replace -0 in Platform.setDouble/Float --- .../org/apache/spark/unsafe/Platform.java | 6 +++++ .../spark/unsafe/PlatformUtilSuite.java | 13 +++++++++++ .../catalyst/expressions/BoundAttribute.scala | 23 ++++--------------- .../GenerateUnsafeProjectionSuite.scala | 16 ------------- 4 files changed, 23 insertions(+), 35 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java index aca6fca00c48b..e51eed8ce2653 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java @@ -120,6 +120,9 @@ public static float getFloat(Object object, long offset) { } public static void putFloat(Object object, long offset, float value) { + if(value == -0.0f) { + value = 0.0f; + } _UNSAFE.putFloat(object, offset, value); } @@ -128,6 +131,9 @@ public static double getDouble(Object object, long offset) { } public static void putDouble(Object object, long offset, double value) { + if(value == -0.0d) { + value = 0.0d; + } _UNSAFE.putDouble(object, offset, value); } diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java index 3ad9ac7b4de9c..9257cfa46830f 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java @@ -24,6 +24,8 @@ import org.junit.Assert; import org.junit.Test; +import java.nio.ByteBuffer; + public class PlatformUtilSuite { @Test @@ -157,4 +159,15 @@ public void heapMemoryReuse() { Assert.assertEquals(onheap4.size(), 1024 * 1024 + 7); Assert.assertEquals(obj3, onheap4.getBaseObject()); } + + @Test + // SPARK-26021 + public void writeMinusZeroIsReplacedWithZero() { + byte[] doubleBytes = new byte[Double.BYTES]; + byte[] floatBytes = new byte[Float.BYTES]; + Platform.putDouble(doubleBytes, Platform.BYTE_ARRAY_OFFSET, -0.0d); + Platform.putFloat(floatBytes, Platform.BYTE_ARRAY_OFFSET, -0.0f); + Assert.assertEquals(0, Double.compare(0.0d, ByteBuffer.wrap(doubleBytes).getDouble())); + Assert.assertEquals(0, Float.compare(0.0f, ByteBuffer.wrap(floatBytes).getFloat())); + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index f87fe367a8e83..77582e10f9ff2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors.attachTree -import org.apache.spark.sql.catalyst.expressions.codegen.{Block, CodegenContext, CodeGenerator, ExprCode, FalseLiteral, JavaCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral, JavaCode} import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types._ @@ -56,32 +56,17 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) val javaType = JavaCode.javaType(dataType) val value = CodeGenerator.getValue(ctx.INPUT_ROW, dataType, ordinal.toString) if (nullable) { - var codeBlock = + ev.copy(code = code""" |boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal); |$javaType ${ev.value} = ${ev.isNull} ? | ${CodeGenerator.defaultValue(dataType)} : ($value); - """.stripMargin - codeBlock = codeBlock + genReplaceMinusZeroWithZeroCode(javaType.codeString, ev.value) - ev.copy(code = codeBlock) + """.stripMargin) } else { - var codeBlock = code"$javaType ${ev.value} = $value;" - codeBlock = codeBlock + genReplaceMinusZeroWithZeroCode(javaType.codeString, ev.value) - ev.copy(code = codeBlock, isNull = FalseLiteral) + ev.copy(code = code"$javaType ${ev.value} = $value;", isNull = FalseLiteral) } } } - - private def genReplaceMinusZeroWithZeroCode(javaType: String, value: String): Block = { - val code = s"\nif ($value == -0.0%c) $value = 0.0%c;" - var formattedCode = "" - javaType match { - case "double" | "java.lang.Double" => formattedCode = code.format('d', 'd') - case "float" | "java.lang.Float" => formattedCode = code.format('f', 'f') - case _ => - } - code"$formattedCode" - } } object BindReferences extends Logging { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjectionSuite.scala index 17bca948369b6..01aa3579aea98 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjectionSuite.scala @@ -68,22 +68,6 @@ class GenerateUnsafeProjectionSuite extends SparkFunSuite { assert(!result3.getStruct(0, 2).isNullAt(0)) assert(!result3.getStruct(0, 3).isNullAt(0)) } - - test("SPARK-26021: Test replacing -0.0 with 0.0") { - val exprs = - BoundReference(0, DoubleType, nullable = false) :: - BoundReference(1, DoubleType, nullable = true) :: - BoundReference(2, FloatType, nullable = false) :: - BoundReference(3, FloatType, nullable = true) :: - Nil - val projection = GenerateUnsafeProjection.generate(exprs) - val result = projection.apply(InternalRow(-0.0d, Double.box(-0.0d), -0.0f, Float.box(-0.0f))) - // using compare since 0.0 == -0.0 is true - assert(result.getDouble(0).compareTo(0.0d) == 0) - assert(result.getDouble(1).compareTo(0.0d) == 0) - assert(result.getFloat(2).compareTo(0.0f) == 0) - assert(result.getFloat(3).compareTo(0.0f) == 0) - } } object AlwaysNull extends InternalRow { From 28bd429429549094471c93eb5145c96804b13a17 Mon Sep 17 00:00:00 2001 From: Alon Doron Date: Mon, 19 Nov 2018 17:21:58 +0000 Subject: [PATCH 4/7] move isNan check to Platform --- .../src/main/java/org/apache/spark/unsafe/Platform.java | 8 ++++++-- .../apache/spark/sql/catalyst/expressions/UnsafeRow.java | 6 ------ .../sql/catalyst/expressions/codegen/UnsafeWriter.java | 6 ------ 3 files changed, 6 insertions(+), 14 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java index e51eed8ce2653..46dc481a09d98 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java @@ -120,7 +120,9 @@ public static float getFloat(Object object, long offset) { } public static void putFloat(Object object, long offset, float value) { - if(value == -0.0f) { + if (Float.isNaN(value)) { + value = Float.NaN; + } else if(value == -0.0f) { value = 0.0f; } _UNSAFE.putFloat(object, offset, value); @@ -131,7 +133,9 @@ public static double getDouble(Object object, long offset) { } public static void putDouble(Object object, long offset, double value) { - if(value == -0.0d) { + if (Double.isNaN(value)) { + value = Double.NaN; + } else if(value == -0.0d) { value = 0.0d; } _UNSAFE.putDouble(object, offset, value); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index a76e6ef8c91c1..9bf9452855f5f 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -224,9 +224,6 @@ public void setLong(int ordinal, long value) { public void setDouble(int ordinal, double value) { assertIndexIsValid(ordinal); setNotNullAt(ordinal); - if (Double.isNaN(value)) { - value = Double.NaN; - } Platform.putDouble(baseObject, getFieldOffset(ordinal), value); } @@ -255,9 +252,6 @@ public void setByte(int ordinal, byte value) { public void setFloat(int ordinal, float value) { assertIndexIsValid(ordinal); setNotNullAt(ordinal); - if (Float.isNaN(value)) { - value = Float.NaN; - } Platform.putFloat(baseObject, getFieldOffset(ordinal), value); } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java index 2781655002000..95263a0da95a8 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java @@ -199,16 +199,10 @@ protected final void writeLong(long offset, long value) { } protected final void writeFloat(long offset, float value) { - if (Float.isNaN(value)) { - value = Float.NaN; - } Platform.putFloat(getBuffer(), offset, value); } protected final void writeDouble(long offset, double value) { - if (Double.isNaN(value)) { - value = Double.NaN; - } Platform.putDouble(getBuffer(), offset, value); } } From 20d56ebdcf81d04548509a51c38884e3549f38e3 Mon Sep 17 00:00:00 2001 From: Alon Doron Date: Tue, 20 Nov 2018 10:16:21 +0000 Subject: [PATCH 5/7] simplify test --- .../spark/sql/DataFrameAggregateSuite.scala | 34 ++++++------------- 1 file changed, 10 insertions(+), 24 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index d34933e13ae05..e2dc7ad307f60 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -726,29 +726,15 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { test("SPARK-26021: Double and Float 0.0/-0.0 should be equal when grouping") { val colName = "i" - def groupByCollect(df: DataFrame): Array[Row] = { - df.groupBy(colName).count().collect() - } - def assertResult[T](result: Array[Row], zero: T)(implicit ordering: Ordering[T]): Unit = { - assert(result.length == 1) - // using compare since 0.0 == -0.0 is true - assert(ordering.compare(result(0).getAs[T](0), zero) == 0) - assert(result(0).getLong(1) == 3) - } - - spark.conf.set("spark.sql.codegen.wholeStage", "false") - val doubles = - groupByCollect(Seq(0.0d, 0.0d, -0.0d).toDF(colName)) - val doublesBoxed = - groupByCollect(Seq(Double.box(0.0d), Double.box(0.0d), Double.box(-0.0d)).toDF(colName)) - val floats = - groupByCollect(Seq(0.0f, -0.0f, 0.0f).toDF(colName)) - val floatsBoxed = - groupByCollect(Seq(Float.box(0.0f), Float.box(-0.0f), Float.box(0.0f)).toDF(colName)) - - assertResult(doubles, 0.0d) - assertResult(doublesBoxed, 0.0d) - assertResult(floats, 0.0f) - assertResult(floatsBoxed, 0.0f) + val doubles = Seq(0.0d, 0.0d, -0.0d).toDF(colName).groupBy(colName).count().collect() + val floats = Seq(0.0f, -0.0f, 0.0f).toDF(colName).groupBy(colName).count().collect() + + assert(doubles.length == 1) + assert(floats.length == 1) + // using compare since 0.0 == -0.0 is true + assert(java.lang.Double.compare(doubles(0).getDouble(0), 0.0d) == 0) + assert(java.lang.Float.compare(floats(0).getFloat(0), 0.0f) == 0) + assert(doubles(0).getLong(1) == 3) + assert(floats(0).getLong(1) == 3) } } From a07e614466d08453815e81b937003f1f5eba75ac Mon Sep 17 00:00:00 2001 From: Alon Doron Date: Wed, 21 Nov 2018 11:36:39 +0000 Subject: [PATCH 6/7] style + test change --- .../src/main/java/org/apache/spark/unsafe/Platform.java | 4 ++-- .../java/org/apache/spark/unsafe/PlatformUtilSuite.java | 9 +++++---- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java index 46dc481a09d98..bc94f2171228a 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java @@ -122,7 +122,7 @@ public static float getFloat(Object object, long offset) { public static void putFloat(Object object, long offset, float value) { if (Float.isNaN(value)) { value = Float.NaN; - } else if(value == -0.0f) { + } else if (value == -0.0f) { value = 0.0f; } _UNSAFE.putFloat(object, offset, value); @@ -135,7 +135,7 @@ public static double getDouble(Object object, long offset) { public static void putDouble(Object object, long offset, double value) { if (Double.isNaN(value)) { value = Double.NaN; - } else if(value == -0.0d) { + } else if (value == -0.0d) { value = 0.0d; } _UNSAFE.putDouble(object, offset, value); diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java index 9257cfa46830f..ab34324eb54cc 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java @@ -24,8 +24,6 @@ import org.junit.Assert; import org.junit.Test; -import java.nio.ByteBuffer; - public class PlatformUtilSuite { @Test @@ -167,7 +165,10 @@ public void writeMinusZeroIsReplacedWithZero() { byte[] floatBytes = new byte[Float.BYTES]; Platform.putDouble(doubleBytes, Platform.BYTE_ARRAY_OFFSET, -0.0d); Platform.putFloat(floatBytes, Platform.BYTE_ARRAY_OFFSET, -0.0f); - Assert.assertEquals(0, Double.compare(0.0d, ByteBuffer.wrap(doubleBytes).getDouble())); - Assert.assertEquals(0, Float.compare(0.0f, ByteBuffer.wrap(floatBytes).getFloat())); + double doubleFromPlatform = Platform.getDouble(doubleBytes, Platform.BYTE_ARRAY_OFFSET); + float floatFromPlatform = Platform.getFloat(floatBytes, Platform.BYTE_ARRAY_OFFSET); + + Assert.assertEquals(Double.doubleToLongBits(0.0d), Double.doubleToLongBits(doubleFromPlatform)); + Assert.assertEquals(Float.floatToIntBits(0.0f), Float.floatToIntBits(floatFromPlatform)); } } From 03408d3d44a201040fe9996b213c6b923f1c97dc Mon Sep 17 00:00:00 2001 From: Alon Doron Date: Thu, 22 Nov 2018 13:36:08 +0000 Subject: [PATCH 7/7] fix tests --- .../scala/org/apache/spark/sql/DataFrameAggregateSuite.scala | 2 +- sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index e2dc7ad307f60..ff64edcd07f4b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -726,7 +726,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { test("SPARK-26021: Double and Float 0.0/-0.0 should be equal when grouping") { val colName = "i" - val doubles = Seq(0.0d, 0.0d, -0.0d).toDF(colName).groupBy(colName).count().collect() + val doubles = Seq(0.0d, -0.0d, 0.0d).toDF(colName).groupBy(colName).count().collect() val floats = Seq(0.0f, -0.0f, 0.0f).toDF(colName).groupBy(colName).count().collect() assert(doubles.length == 1) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index baca9c1cfb9a0..8ba67239fb907 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -289,7 +289,7 @@ object QueryTest { def prepareRow(row: Row): Row = { Row.fromSeq(row.toSeq.map { case null => null - case d: java.math.BigDecimal => BigDecimal(d) + case bd: java.math.BigDecimal => BigDecimal(bd) // Equality of WrappedArray differs for AnyVal and AnyRef in Scala 2.12.2+ case seq: Seq[_] => seq.map { case b: java.lang.Byte => b.byteValue @@ -303,6 +303,9 @@ object QueryTest { // Convert array to Seq for easy equality check. case b: Array[_] => b.toSeq case r: Row => prepareRow(r) + // spark treats -0.0 as 0.0 + case d: Double if d == -0.0d => 0.0d + case f: Float if f == -0.0f => 0.0f case o => o }) }