[SPARK-32764][SQL] -0.0 should be equal to 0.0
### What changes were proposed in this pull request?

This is a Spark 3.0 regression introduced by #26761. We missed a corner case: `java.lang.Double.compare` treats 0.0 and -0.0 as different, which breaks SQL semantics.

This PR adds back `SQLOrderingUtil`, which provides custom compare methods that handle 0.0 vs -0.0.
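
For context, a minimal sketch of the mismatch in plain Scala (illustration only, not part of the PR):

```scala
// IEEE 754 `==` and the total order of java.lang.Double.compare disagree
// on signed zeros (and, in the other direction, on NaN):
assert(0.0 == -0.0)                              // numerically equal
assert(java.lang.Double.compare(0.0, -0.0) > 0)  // but compare() says 0.0 > -0.0
assert(Double.NaN != Double.NaN)                 // `==` is never true for NaN
assert(java.lang.Double.compare(Double.NaN, Double.NaN) == 0)
```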

### Why are the changes needed?

Fix a correctness bug.

### Does this PR introduce _any_ user-facing change?

Yes, `SELECT 0.0 > -0.0` now correctly returns false, as in Spark 2.x.

### How was this patch tested?

New tests.

Closes #29647 from cloud-fan/float.

Authored-by: Wenchen Fan <wenchen@databricks.com>
Signed-off-by: Dongjoon Hyun <dongjoon@apache.org>
cloud-fan authored and dongjoon-hyun committed Sep 8, 2020
1 parent c8c082c commit 4144b6d
Showing 8 changed files with 151 additions and 7 deletions.
@@ -38,7 +38,7 @@ import org.apache.spark.metrics.source.CodegenMetrics
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
-import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData}
+import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, SQLOrderingUtil}
 import org.apache.spark.sql.catalyst.util.DateTimeConstants.NANOS_PER_MILLIS
 import org.apache.spark.sql.catalyst.util.DateTimeUtils._
 import org.apache.spark.sql.internal.SQLConf
@@ -624,8 +624,12 @@ class CodegenContext extends Logging {
   def genComp(dataType: DataType, c1: String, c2: String): String = dataType match {
     // java boolean doesn't support > or < operator
     case BooleanType => s"($c1 == $c2 ? 0 : ($c1 ? 1 : -1))"
-    case DoubleType => s"java.lang.Double.compare($c1, $c2)"
-    case FloatType => s"java.lang.Float.compare($c1, $c2)"
+    case DoubleType =>
+      val clsName = SQLOrderingUtil.getClass.getName.stripSuffix("$")
+      s"$clsName.compareDoubles($c1, $c2)"
+    case FloatType =>
+      val clsName = SQLOrderingUtil.getClass.getName.stripSuffix("$")
+      s"$clsName.compareFloats($c1, $c2)"
     // use c1 - c2 may overflow
     case dt: DataType if isPrimitiveType(dt) => s"($c1 > $c2 ? 1 : $c1 < $c2 ? -1 : 0)"
     case BinaryType => s"org.apache.spark.sql.catalyst.util.TypeUtils.compareBinary($c1, $c2)"
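A note on the `stripSuffix("$")` idiom above (background, not part of the diff): a top-level Scala object compiles to a JVM class whose name ends in `$` and which holds the singleton, and, because `SQLOrderingUtil` has no companion class, the compiler also emits static forwarder methods on the plain class name. The generated Java code calls those forwarders:

```scala
// getClass.getName on a Scala object returns the `$`-suffixed class name,
// so stripping the suffix yields a name with callable static methods:
val clsName = SQLOrderingUtil.getClass.getName.stripSuffix("$")
// clsName == "org.apache.spark.sql.catalyst.util.SQLOrderingUtil"
// genComp then emits a plain static call in the generated Java, e.g.:
//   org.apache.spark.sql.catalyst.util.SQLOrderingUtil.compareDoubles(a, b)
```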
@@ -0,0 +1,41 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.util

object SQLOrderingUtil {

  /**
   * A special version of double comparison that follows SQL semantics:
   * 1. NaN == NaN
   * 2. NaN is greater than any non-NaN double
   * 3. -0.0 == 0.0
   */
  def compareDoubles(x: Double, y: Double): Int = {
    if (x == y) 0 else java.lang.Double.compare(x, y)
  }

  /**
   * A special version of float comparison that follows SQL semantics:
   * 1. NaN == NaN
   * 2. NaN is greater than any non-NaN float
   * 3. -0.0 == 0.0
   */
  def compareFloats(x: Float, y: Float): Int = {
    if (x == y) 0 else java.lang.Float.compare(x, y)
  }
}
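Why the single `==` guard is enough: under IEEE 754, `0.0 == -0.0` is true, so signed zeros take the first branch, while `NaN == NaN` is false, so any comparison involving NaN falls through to `java.lang.Double.compare` (or `Float.compare`), which already treats every NaN as equal to NaN and greater than any other value, including positive infinity.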
@@ -22,6 +22,7 @@ import scala.reflect.runtime.universe.typeTag
 import scala.util.Try
 
 import org.apache.spark.annotation.Stable
+import org.apache.spark.sql.catalyst.util.SQLOrderingUtil
 
 /**
  * The data type representing `Double` values. Please use the singleton `DataTypes.DoubleType`.
@@ -38,7 +39,7 @@ class DoubleType private() extends FractionalType {
   private[sql] val numeric = implicitly[Numeric[Double]]
   private[sql] val fractional = implicitly[Fractional[Double]]
   private[sql] val ordering =
-    (x: Double, y: Double) => java.lang.Double.compare(x, y)
+    (x: Double, y: Double) => SQLOrderingUtil.compareDoubles(x, y)
   private[sql] val asIntegral = DoubleType.DoubleAsIfIntegral
 
   override private[sql] def exactNumeric = DoubleExactNumeric
@@ -22,6 +22,7 @@ import scala.reflect.runtime.universe.typeTag
 import scala.util.Try
 
 import org.apache.spark.annotation.Stable
+import org.apache.spark.sql.catalyst.util.SQLOrderingUtil
 
 /**
  * The data type representing `Float` values. Please use the singleton `DataTypes.FloatType`.
@@ -38,7 +39,7 @@ class FloatType private() extends FractionalType {
   private[sql] val numeric = implicitly[Numeric[Float]]
   private[sql] val fractional = implicitly[Fractional[Float]]
   private[sql] val ordering =
-    (x: Float, y: Float) => java.lang.Float.compare(x, y)
+    (x: Float, y: Float) => SQLOrderingUtil.compareFloats(x, y)
   private[sql] val asIntegral = FloatType.FloatAsIfIntegral
 
   override private[sql] def exactNumeric = FloatExactNumeric
@@ -20,6 +20,7 @@ package org.apache.spark.sql.types
 import scala.math.Numeric._
 import scala.math.Ordering
 
+import org.apache.spark.sql.catalyst.util.SQLOrderingUtil
 import org.apache.spark.sql.types.Decimal.DecimalIsConflicted
 
 private[sql] object ByteExactNumeric extends ByteIsIntegral with Ordering.ByteOrdering {
@@ -148,7 +149,7 @@ private[sql] object FloatExactNumeric extends FloatIsFractional {
     }
   }
 
-  override def compare(x: Float, y: Float): Int = java.lang.Float.compare(x, y)
+  override def compare(x: Float, y: Float): Int = SQLOrderingUtil.compareFloats(x, y)
 }
 
 private[sql] object DoubleExactNumeric extends DoubleIsFractional {
@@ -176,7 +177,7 @@ private[sql] object DoubleExactNumeric extends DoubleIsFractional {
     }
   }
 
-  override def compare(x: Double, y: Double): Int = java.lang.Double.compare(x, y)
+  override def compare(x: Double, y: Double): Int = SQLOrderingUtil.compareDoubles(x, y)
 }
 
 private[sql] object DecimalExactNumeric extends DecimalIsConflicted {
@@ -538,4 +538,20 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
     val inSet = InSet(BoundReference(0, IntegerType, true), Set.empty)
     checkEvaluation(inSet, false, row)
   }
+
+  test("SPARK-32764: compare special double/float values") {
+    checkEvaluation(EqualTo(Literal(Double.NaN), Literal(Double.NaN)), true)
+    checkEvaluation(EqualTo(Literal(Double.NaN), Literal(Double.PositiveInfinity)), false)
+    checkEvaluation(EqualTo(Literal(0.0D), Literal(-0.0D)), true)
+    checkEvaluation(GreaterThan(Literal(Double.NaN), Literal(Double.PositiveInfinity)), true)
+    checkEvaluation(GreaterThan(Literal(Double.NaN), Literal(Double.NaN)), false)
+    checkEvaluation(GreaterThan(Literal(0.0D), Literal(-0.0D)), false)
+
+    checkEvaluation(EqualTo(Literal(Float.NaN), Literal(Float.NaN)), true)
+    checkEvaluation(EqualTo(Literal(Float.NaN), Literal(Float.PositiveInfinity)), false)
+    checkEvaluation(EqualTo(Literal(0.0F), Literal(-0.0F)), true)
+    checkEvaluation(GreaterThan(Literal(Float.NaN), Literal(Float.PositiveInfinity)), true)
+    checkEvaluation(GreaterThan(Literal(Float.NaN), Literal(Float.NaN)), false)
+    checkEvaluation(GreaterThan(Literal(0.0F), Literal(-0.0F)), false)
+  }
 }
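Worth noting: `checkEvaluation` (from `ExpressionEvalHelper`) evaluates each expression through both the interpreted and the codegen paths, so this one test exercises both the `SQLOrderingUtil` methods and the comparison code emitted by `genComp`.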
@@ -0,0 +1,75 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.util

import java.lang.{Double => JDouble, Float => JFloat}

import org.apache.spark.SparkFunSuite

class SQLOrderingUtilSuite extends SparkFunSuite {

  test("compareDoublesSQL") {
    def shouldMatchDefaultOrder(a: Double, b: Double): Unit = {
      assert(SQLOrderingUtil.compareDoubles(a, b) === JDouble.compare(a, b))
      assert(SQLOrderingUtil.compareDoubles(b, a) === JDouble.compare(b, a))
    }
    shouldMatchDefaultOrder(0d, 0d)
    shouldMatchDefaultOrder(0d, 1d)
    shouldMatchDefaultOrder(-1d, 1d)
    shouldMatchDefaultOrder(Double.MinValue, Double.MaxValue)

    // A NaN whose bit pattern differs from the canonical Double.NaN.
    val specialNaN = JDouble.longBitsToDouble(0x7ff1234512345678L)
    assert(JDouble.isNaN(specialNaN))
    assert(JDouble.doubleToRawLongBits(Double.NaN) != JDouble.doubleToRawLongBits(specialNaN))

    assert(SQLOrderingUtil.compareDoubles(Double.NaN, Double.NaN) === 0)
    assert(SQLOrderingUtil.compareDoubles(Double.NaN, specialNaN) === 0)
    assert(SQLOrderingUtil.compareDoubles(Double.NaN, Double.PositiveInfinity) > 0)
    assert(SQLOrderingUtil.compareDoubles(specialNaN, Double.PositiveInfinity) > 0)
    assert(SQLOrderingUtil.compareDoubles(Double.NaN, Double.NegativeInfinity) > 0)
    assert(SQLOrderingUtil.compareDoubles(Double.PositiveInfinity, Double.NaN) < 0)
    assert(SQLOrderingUtil.compareDoubles(Double.NegativeInfinity, Double.NaN) < 0)
    assert(SQLOrderingUtil.compareDoubles(0.0d, -0.0d) === 0)
    assert(SQLOrderingUtil.compareDoubles(-0.0d, 0.0d) === 0)
  }

  test("compareFloatsSQL") {
    def shouldMatchDefaultOrder(a: Float, b: Float): Unit = {
      assert(SQLOrderingUtil.compareFloats(a, b) === JFloat.compare(a, b))
      assert(SQLOrderingUtil.compareFloats(b, a) === JFloat.compare(b, a))
    }
    shouldMatchDefaultOrder(0f, 0f)
    shouldMatchDefaultOrder(0f, 1f)
    shouldMatchDefaultOrder(-1f, 1f)
    shouldMatchDefaultOrder(Float.MinValue, Float.MaxValue)

    // A NaN whose bit pattern differs from the canonical Float.NaN.
    val specialNaN = JFloat.intBitsToFloat(-6966608)
    assert(JFloat.isNaN(specialNaN))
    assert(JFloat.floatToRawIntBits(Float.NaN) != JFloat.floatToRawIntBits(specialNaN))

    assert(SQLOrderingUtil.compareFloats(Float.NaN, Float.NaN) === 0)
    assert(SQLOrderingUtil.compareFloats(Float.NaN, specialNaN) === 0)
    assert(SQLOrderingUtil.compareFloats(Float.NaN, Float.PositiveInfinity) > 0)
    assert(SQLOrderingUtil.compareFloats(specialNaN, Float.PositiveInfinity) > 0)
    assert(SQLOrderingUtil.compareFloats(Float.NaN, Float.NegativeInfinity) > 0)
    assert(SQLOrderingUtil.compareFloats(Float.PositiveInfinity, Float.NaN) < 0)
    assert(SQLOrderingUtil.compareFloats(Float.NegativeInfinity, Float.NaN) < 0)
    assert(SQLOrderingUtil.compareFloats(0.0f, -0.0f) === 0)
    assert(SQLOrderingUtil.compareFloats(-0.0f, 0.0f) === 0)
  }
}
@@ -2550,6 +2550,11 @@ class DataFrameSuite extends QueryTest
   test("SPARK-32761: aggregating multiple distinct CONSTANT columns") {
     checkAnswer(sql("select count(distinct 2), count(distinct 2,3)"), Row(1, 1))
   }
+
+  test("SPARK-32764: -0.0 and 0.0 should be equal") {
+    val df = Seq(0.0 -> -0.0).toDF("pos", "neg")
+    checkAnswer(df.select($"pos" > $"neg"), Row(false))
+  }
 }

case class GroupByKey(a: Int, b: Int)
