diff --git a/docs/sql-migration-guide-upgrade.md b/docs/sql-migration-guide-upgrade.md index 0fcdd420bcfe3..4e36fd45fb04e 100644 --- a/docs/sql-migration-guide-upgrade.md +++ b/docs/sql-migration-guide-upgrade.md @@ -25,7 +25,7 @@ displayTitle: Spark SQL Upgrading Guide - In Spark version 2.4 and earlier, `Dataset.groupByKey` results to a grouped dataset with key attribute wrongly named as "value", if the key is non-struct type, e.g. int, string, array, etc. This is counterintuitive and makes the schema of aggregation queries weird. For example, the schema of `ds.groupByKey(...).count()` is `(value, count)`. Since Spark 3.0, we name the grouping attribute to "key". The old behaviour is preserved under a newly added configuration `spark.sql.legacy.dataset.nameNonStructGroupingKeyAsValue` with a default value of `false`. - - In Spark version 2.4 and earlier, float/double -0.0 is semantically equal to 0.0, but users can still distinguish them via `Dataset.show`, `Dataset.collect` etc. Since Spark 3.0, float/double -0.0 is replaced by 0.0 internally, and users can't distinguish them any more. + - In Spark version 2.4 and earlier, float/double -0.0 is semantically equal to 0.0, but -0.0 and 0.0 are considered as different values when used in aggregate grouping keys, window partition keys and join keys. Since Spark 3.0, this bug is fixed. For example, `Seq(-0.0, 0.0).toDF("d").groupBy("d").count()` returns `[(0.0, 2)]` in Spark 3.0, and `[(0.0, 1), (-0.0, 1)]` in Spark 2.4 and earlier. - In Spark version 2.4 and earlier, users can create a map with duplicated keys via built-in functions like `CreateMap`, `StringToMap`, etc. The behavior of map with duplicated keys is undefined, e.g. map look up respects the duplicated key appears first, `Dataset.collect` only keeps the duplicated key appears last, `MapKeys` returns duplicated keys, etc. Since Spark 3.0, these built-in functions will remove duplicated map keys with last wins policy. Users may still read map values with duplicated keys from data sources which do not enforce it (e.g. Parquet), the behavior will be undefined. diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java index 7553ab8cf7000..95263a0da95a8 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java @@ -198,46 +198,11 @@ protected final void writeLong(long offset, long value) { Platform.putLong(getBuffer(), offset, value); } - // We need to take care of NaN and -0.0 in several places: - // 1. When compare values, different NaNs should be treated as same, `-0.0` and `0.0` should be - // treated as same. - // 2. In GROUP BY, different NaNs should belong to the same group, -0.0 and 0.0 should belong - // to the same group. - // 3. As join keys, different NaNs should be treated as same, `-0.0` and `0.0` should be - // treated as same. - // 4. As window partition keys, different NaNs should be treated as same, `-0.0` and `0.0` - // should be treated as same. - // - // Case 1 is fine, as we handle NaN and -0.0 well during comparison. For complex types, we - // recursively compare the fields/elements, so it's also fine. 
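+  // SPARK-26021: NaN and -0.0 are no longer normalized when writing unsafe data. The optimizer
+  // rule `NormalizeFloatingNumbers` now rewrites grouping/join/window partition keys instead, so
+  // these writers store the value bits as-is.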
- // - // Case 2, 3 and 4 are problematic, as they compare `UnsafeRow` binary directly, and different - // NaNs have different binary representation, and the same thing happens for -0.0 and 0.0. - // - // Here we normalize NaN and -0.0, so that `UnsafeProjection` will normalize them when writing - // float/double columns and nested fields to `UnsafeRow`. - // - // Note that, we must do this for all the `UnsafeProjection`s, not only the ones that extract - // join/grouping/window partition keys. `UnsafeProjection` copies unsafe data directly for complex - // types, so nested float/double may not be normalized. We need to make sure that all the unsafe - // data(`UnsafeRow`, `UnsafeArrayData`, `UnsafeMapData`) will have flat/double normalized during - // creation. protected final void writeFloat(long offset, float value) { - if (Float.isNaN(value)) { - value = Float.NaN; - } else if (value == -0.0f) { - value = 0.0f; - } Platform.putFloat(getBuffer(), offset, value); } - // See comments for `writeFloat`. protected final void writeDouble(long offset, double value) { - if (Double.isNaN(value)) { - value = Double.NaN; - } else if (value == -0.0d) { - value = 0.0d; - } Platform.putDouble(getBuffer(), offset, value); } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala new file mode 100644 index 0000000000000..520f24aa22e4c --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala @@ -0,0 +1,198 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.expressions.{Alias, And, ArrayTransform, CreateArray, CreateMap, CreateNamedStruct, CreateNamedStructUnsafe, CreateStruct, EqualTo, ExpectsInputTypes, Expression, GetStructField, LambdaFunction, NamedLambdaVariable, UnaryExpression} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Subquery, Window} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.types._ + +/** + * We need to take care of special floating numbers (NaN and -0.0) in several places: + * 1. When compare values, different NaNs should be treated as same, `-0.0` and `0.0` should be + * treated as same. + * 2. In aggregate grouping keys, different NaNs should belong to the same group, -0.0 and 0.0 + * should belong to the same group. + * 3. In join keys, different NaNs should be treated as same, `-0.0` and `0.0` should be + * treated as same. 
+ * 4. In window partition keys, different NaNs should belong to the same partition, -0.0 and 0.0
+ *    should belong to the same partition.
+ *
+ * Case 1 is fine, as we handle NaN and -0.0 well during comparison. For complex types, we
+ * recursively compare the fields/elements, so it's also fine.
+ *
+ * Cases 2, 3 and 4 are problematic, as Spark SQL turns grouping/join/window partition keys into
+ * a binary `UnsafeRow` and compares the binary data directly. Different NaNs have different
+ * binary representations, and the same thing happens for -0.0 and 0.0.
+ *
+ * This rule normalizes NaN and -0.0 in window partition keys, join keys and aggregate grouping
+ * keys.
+ *
+ * Ideally we should do the normalization in the physical operators that compare the binary
+ * `UnsafeRow` directly; this normalization would be unnecessary if the Spark SQL execution engine
+ * were not optimized to run on binary data. This rule is created to simplify the implementation,
+ * so that we have a single place to do normalization, which is more maintainable.
+ *
+ * Note that this rule must be executed at the end of the optimizer, because the optimizer may
+ * create new joins (the subquery rewrite) and new join conditions (the join reorder).
+ */
+object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
+
+  def apply(plan: LogicalPlan): LogicalPlan = plan match {
+    // A subquery will be rewritten into a join later, and will go through this rule
+    // eventually. Here we skip subqueries, as we only need to run this rule once.
+    case _: Subquery => plan
+
+    case _ => plan transform {
+      case w: Window if w.partitionSpec.exists(p => needNormalize(p.dataType)) =>
+        // Although the `windowExpressions` may refer to `partitionSpec` expressions, we don't need
+        // to normalize the `windowExpressions`, as they are executed per input row and should take
+        // the input row as it is.
+        w.copy(partitionSpec = w.partitionSpec.map(normalize))
+
+      // Only hash join and sort merge join need the normalization. Here we catch all Joins with
+      // join keys, assuming Joins with join keys are always planned as hash join or sort merge
+      // join. It's very unlikely that we will break this assumption in the near future.
+      case j @ ExtractEquiJoinKeys(_, leftKeys, rightKeys, condition, _, _, _)
+          // The analyzer guarantees left and right join keys are of the same data type. Here we
+          // only need to check the join keys of one side.
+          if leftKeys.exists(k => needNormalize(k.dataType)) =>
+        val newLeftJoinKeys = leftKeys.map(normalize)
+        val newRightJoinKeys = rightKeys.map(normalize)
+        val newConditions = newLeftJoinKeys.zip(newRightJoinKeys).map {
+          case (l, r) => EqualTo(l, r)
+        } ++ condition
+        j.copy(condition = Some(newConditions.reduce(And)))
+
+      // TODO: ideally Aggregate should also be handled here, but its grouping expressions are
+      // mixed into its aggregate expressions. It's unreliable to change the grouping expressions
+      // here. For now we normalize grouping expressions in `AggUtils` during planning.
+    }
+  }
+
+  private def needNormalize(dt: DataType): Boolean = dt match {
+    case FloatType | DoubleType => true
+    case StructType(fields) => fields.exists(f => needNormalize(f.dataType))
+    case ArrayType(et, _) => needNormalize(et)
+    // Currently MapType is not comparable and analyzer should fail earlier if this case happens.
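+    // (e.g. grouping or joining on a map-typed key is rejected during analysis, so this branch is
+    // only a safety net.)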
+ case _: MapType => + throw new IllegalStateException("grouping/join/window partition keys cannot be map type.") + case _ => false + } + + private[sql] def normalize(expr: Expression): Expression = expr match { + case _ if expr.dataType == FloatType || expr.dataType == DoubleType => + NormalizeNaNAndZero(expr) + + case CreateNamedStruct(children) => + CreateNamedStruct(children.map(normalize)) + + case CreateNamedStructUnsafe(children) => + CreateNamedStructUnsafe(children.map(normalize)) + + case CreateArray(children) => + CreateArray(children.map(normalize)) + + case CreateMap(children) => + CreateMap(children.map(normalize)) + + case a: Alias if needNormalize(a.dataType) => + a.withNewChildren(Seq(normalize(a.child))) + + case _ if expr.dataType.isInstanceOf[StructType] && needNormalize(expr.dataType) => + val fields = expr.dataType.asInstanceOf[StructType].fields.indices.map { i => + normalize(GetStructField(expr, i)) + } + CreateStruct(fields) + + case _ if expr.dataType.isInstanceOf[ArrayType] && needNormalize(expr.dataType) => + val ArrayType(et, containsNull) = expr.dataType + val lv = NamedLambdaVariable("arg", et, containsNull) + val function = normalize(lv) + ArrayTransform(expr, LambdaFunction(function, Seq(lv))) + + case _ => expr + } +} + +case class NormalizeNaNAndZero(child: Expression) extends UnaryExpression with ExpectsInputTypes { + + override def dataType: DataType = child.dataType + + override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(FloatType, DoubleType)) + + private lazy val normalizer: Any => Any = child.dataType match { + case FloatType => (input: Any) => { + val f = input.asInstanceOf[Float] + if (f.isNaN) { + Float.NaN + } else if (f == -0.0f) { + 0.0f + } else { + f + } + } + + case DoubleType => (input: Any) => { + val d = input.asInstanceOf[Double] + if (d.isNaN) { + Double.NaN + } else if (d == -0.0d) { + 0.0d + } else { + d + } + } + } + + override def nullSafeEval(input: Any): Any = { + normalizer(input) + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val codeToNormalize = child.dataType match { + case FloatType => (f: String) => { + s""" + |if (Float.isNaN($f)) { + | ${ev.value} = Float.NaN; + |} else if ($f == -0.0f) { + | ${ev.value} = 0.0f; + |} else { + | ${ev.value} = $f; + |} + """.stripMargin + } + + case DoubleType => (d: String) => { + s""" + |if (Double.isNaN($d)) { + | ${ev.value} = Double.NaN; + |} else if ($d == -0.0d) { + | ${ev.value} = 0.0d; + |} else { + | ${ev.value} = $d; + |} + """.stripMargin + } + } + + nullSafeCodeGen(ctx, ev, codeToNormalize) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 06f908281dd3c..d51dc6663d434 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -180,7 +180,9 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) CollapseProject, RemoveNoopOperators) :+ Batch("UpdateAttributeReferences", Once, - UpdateNullabilityInAttributeReferences) + UpdateNullabilityInAttributeReferences) :+ + // This batch must be executed after the `RewriteSubquery` batch, which creates joins. 
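+    // (`RewritePredicateSubquery` turns IN/EXISTS subqueries into joins, and join reordering can
+    // introduce new join conditions, so running this rule any earlier could miss some join keys.)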
+ Batch("NormalizeFloatingNumbers", Once, NormalizeFloatingNumbers) } /** @@ -210,7 +212,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) PullupCorrelatedPredicates.ruleName :: RewriteCorrelatedScalarSubquery.ruleName :: RewritePredicateSubquery.ruleName :: - PullOutPythonUDFInJoinCondition.ruleName :: Nil + PullOutPythonUDFInJoinCondition.ruleName :: + NormalizeFloatingNumbers.ruleName :: Nil /** * Optimize all the subqueries inside expression. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index dfc3b2d22129d..95be0a52cb2ed 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -105,8 +105,8 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper { (JoinType, Seq[Expression], Seq[Expression], Option[Expression], LogicalPlan, LogicalPlan, JoinHint) - def unapply(plan: LogicalPlan): Option[ReturnType] = plan match { - case join @ Join(left, right, joinType, condition, hint) => + def unapply(join: Join): Option[ReturnType] = join match { + case Join(left, right, joinType, condition, hint) => logDebug(s"Considering join on: $condition") // Find equi-join predicates that can be evaluated before the join, and thus can be used // as join keys. @@ -140,7 +140,6 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper { } else { None } - case _ => None } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala index ecb8047459b0c..69523fa81bc65 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala @@ -246,22 +246,6 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers with PlanTestB // assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11)) } - testBothCodegenAndInterpreted("NaN canonicalization") { - val factory = UnsafeProjection - val fieldTypes: Array[DataType] = Array(FloatType, DoubleType) - - val row1 = new SpecificInternalRow(fieldTypes) - row1.setFloat(0, java.lang.Float.intBitsToFloat(0x7f800001)) - row1.setDouble(1, java.lang.Double.longBitsToDouble(0x7ff0000000000001L)) - - val row2 = new SpecificInternalRow(fieldTypes) - row2.setFloat(0, java.lang.Float.intBitsToFloat(0x7fffffff)) - row2.setDouble(1, java.lang.Double.longBitsToDouble(0x7fffffffffffffffL)) - - val converter = factory.create(fieldTypes) - assert(converter.apply(row1).getBytes === converter.apply(row2).getBytes) - } - testBothCodegenAndInterpreted("basic conversion with struct type") { val factory = UnsafeProjection val fieldTypes: Array[DataType] = Array( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala index 22e1fa6dfed4f..86b8fa54c0fd4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala @@ -49,25 +49,4 @@ class 
UnsafeRowWriterSuite extends SparkFunSuite { // The two rows should be the equal assert(res1 == res2) } - - test("SPARK-26021: normalize float/double NaN and -0.0") { - val unsafeRowWriter1 = new UnsafeRowWriter(4) - unsafeRowWriter1.resetRowWriter() - unsafeRowWriter1.write(0, Float.NaN) - unsafeRowWriter1.write(1, Double.NaN) - unsafeRowWriter1.write(2, 0.0f) - unsafeRowWriter1.write(3, 0.0) - val res1 = unsafeRowWriter1.getRow - - val unsafeRowWriter2 = new UnsafeRowWriter(4) - unsafeRowWriter2.resetRowWriter() - unsafeRowWriter2.write(0, 0.0f/0.0f) - unsafeRowWriter2.write(1, 0.0/0.0) - unsafeRowWriter2.write(2, -0.0f) - unsafeRowWriter2.write(3, -0.0) - val res2 = unsafeRowWriter2.getRow - - // The two rows should be the equal - assert(res1 == res2) - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala index 6be88c463dbd9..8b7556b0c6c5a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala @@ -19,9 +19,9 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.optimizer.NormalizeFloatingNumbers import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.streaming.{StateStoreRestoreExec, StateStoreSaveExec} -import org.apache.spark.sql.internal.SQLConf /** * Utility functions used by the query planner to convert our plan to new aggregation code path. @@ -35,12 +35,20 @@ object AggUtils { initialInputBufferOffset: Int = 0, resultExpressions: Seq[NamedExpression] = Nil, child: SparkPlan): SparkPlan = { + // Ideally this should be done in `NormalizeFloatingNumbers`, but we do it here because + // `groupingExpressions` is not extracted during logical phase. 
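+    // For example, a double grouping attribute `d` becomes `Alias(NormalizeNaNAndZero(d), "d")`,
+    // reusing the original name and exprId so that downstream attribute references still resolve.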
+ val normalizedGroupingExpressions = groupingExpressions.map { e => + NormalizeFloatingNumbers.normalize(e) match { + case n: NamedExpression => n + case other => Alias(other, e.name)(exprId = e.exprId) + } + } val useHash = HashAggregateExec.supportsAggregate( aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)) if (useHash) { HashAggregateExec( requiredChildDistributionExpressions = requiredChildDistributionExpressions, - groupingExpressions = groupingExpressions, + groupingExpressions = normalizedGroupingExpressions, aggregateExpressions = aggregateExpressions, aggregateAttributes = aggregateAttributes, initialInputBufferOffset = initialInputBufferOffset, @@ -53,7 +61,7 @@ object AggUtils { if (objectHashEnabled && useObjectHash) { ObjectHashAggregateExec( requiredChildDistributionExpressions = requiredChildDistributionExpressions, - groupingExpressions = groupingExpressions, + groupingExpressions = normalizedGroupingExpressions, aggregateExpressions = aggregateExpressions, aggregateAttributes = aggregateAttributes, initialInputBufferOffset = initialInputBufferOffset, @@ -62,7 +70,7 @@ object AggUtils { } else { SortAggregateExec( requiredChildDistributionExpressions = requiredChildDistributionExpressions, - groupingExpressions = groupingExpressions, + groupingExpressions = normalizedGroupingExpressions, aggregateExpressions = aggregateExpressions, aggregateAttributes = aggregateAttributes, initialInputBufferOffset = initialInputBufferOffset, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index ff64edcd07f4b..73259a0ed3b50 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -724,17 +724,52 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { "type: GroupBy]")) } - test("SPARK-26021: Double and Float 0.0/-0.0 should be equal when grouping") { - val colName = "i" - val doubles = Seq(0.0d, -0.0d, 0.0d).toDF(colName).groupBy(colName).count().collect() - val floats = Seq(0.0f, -0.0f, 0.0f).toDF(colName).groupBy(colName).count().collect() - - assert(doubles.length == 1) - assert(floats.length == 1) - // using compare since 0.0 == -0.0 is true - assert(java.lang.Double.compare(doubles(0).getDouble(0), 0.0d) == 0) - assert(java.lang.Float.compare(floats(0).getFloat(0), 0.0f) == 0) - assert(doubles(0).getLong(1) == 3) - assert(floats(0).getLong(1) == 3) + test("SPARK-26021: NaN and -0.0 in grouping expressions") { + import java.lang.Float.floatToRawIntBits + import java.lang.Double.doubleToRawLongBits + + // 0.0/0.0 and NaN are different values. 
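+    // Both are NaN, just with different bit patterns; the asserts below prove it, and the grouped
+    // counts then show that they still fall into the same group.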
+ assert(floatToRawIntBits(0.0f/0.0f) != floatToRawIntBits(Float.NaN)) + assert(doubleToRawLongBits(0.0/0.0) != doubleToRawLongBits(Double.NaN)) + + checkAnswer( + Seq(0.0f, -0.0f, 0.0f/0.0f, Float.NaN).toDF("f").groupBy("f").count(), + Row(0.0f, 2) :: Row(Float.NaN, 2) :: Nil) + checkAnswer( + Seq(0.0d, -0.0d, 0.0d/0.0d, Double.NaN).toDF("d").groupBy("d").count(), + Row(0.0d, 2) :: Row(Double.NaN, 2) :: Nil) + + // test with complicated type grouping expressions + checkAnswer( + Seq(0.0f, -0.0f, 0.0f/0.0f, Float.NaN).toDF("f") + .groupBy(array("f"), struct("f")).count(), + Row(Seq(0.0f), Row(0.0f), 2) :: + Row(Seq(Float.NaN), Row(Float.NaN), 2) :: Nil) + checkAnswer( + Seq(0.0d, -0.0d, 0.0d/0.0d, Double.NaN).toDF("d") + .groupBy(array("d"), struct("d")).count(), + Row(Seq(0.0d), Row(0.0d), 2) :: + Row(Seq(Double.NaN), Row(Double.NaN), 2) :: Nil) + + checkAnswer( + Seq(0.0f, -0.0f, 0.0f/0.0f, Float.NaN).toDF("f") + .groupBy(array(struct("f")), struct(array("f"))).count(), + Row(Seq(Row(0.0f)), Row(Seq(0.0f)), 2) :: + Row(Seq(Row(Float.NaN)), Row(Seq(Float.NaN)), 2) :: Nil) + checkAnswer( + Seq(0.0d, -0.0d, 0.0d/0.0d, Double.NaN).toDF("d") + .groupBy(array(struct("d")), struct(array("d"))).count(), + Row(Seq(Row(0.0d)), Row(Seq(0.0d)), 2) :: + Row(Seq(Row(Double.NaN)), Row(Seq(Double.NaN)), 2) :: Nil) + + // test with complicated type grouping columns + val df = Seq( + (Array(-0.0f, 0.0f), Tuple2(-0.0d, Double.NaN), Seq(Tuple2(-0.0d, Double.NaN))), + (Array(0.0f, -0.0f), Tuple2(0.0d, Double.NaN), Seq(Tuple2(0.0d, 0.0/0.0))) + ).toDF("arr", "stru", "arrOfStru") + checkAnswer( + df.groupBy("arr", "stru", "arrOfStru").count(), + Row(Seq(0.0f, 0.0f), Row(0.0d, Double.NaN), Seq(Row(0.0d, Double.NaN)), 2) + ) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index a4a3e2a62d1a5..6bd12cbf0135d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -295,16 +295,4 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { df.join(df, df("id") <=> df("id")).queryExecution.optimizedPlan } } - - test("NaN and -0.0 in join keys") { - val df1 = Seq(Float.NaN -> Double.NaN, 0.0f -> 0.0, -0.0f -> -0.0).toDF("f", "d") - val df2 = Seq(Float.NaN -> Double.NaN, 0.0f -> 0.0, -0.0f -> -0.0).toDF("f", "d") - val joined = df1.join(df2, Seq("f", "d")) - checkAnswer(joined, Seq( - Row(Float.NaN, Double.NaN), - Row(0.0f, 0.0), - Row(0.0f, 0.0), - Row(0.0f, 0.0), - Row(0.0f, 0.0))) - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala index 9277dc6859247..f4ba2f0673c0a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala @@ -697,16 +697,56 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext { } test("NaN and -0.0 in window partition keys") { + import java.lang.Float.floatToRawIntBits + import java.lang.Double.doubleToRawLongBits + + // 0.0/0.0 and NaN are different values. 
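+    // Different NaN bit patterns and -0.0/0.0 should nevertheless land in the same window
+    // partition, as the expected counts below verify.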
+ assert(floatToRawIntBits(0.0f/0.0f) != floatToRawIntBits(Float.NaN)) + assert(doubleToRawLongBits(0.0/0.0) != doubleToRawLongBits(Double.NaN)) + val df = Seq( - (Float.NaN, Double.NaN, 1), - (0.0f/0.0f, 0.0/0.0, 1), - (0.0f, 0.0, 1), - (-0.0f, -0.0, 1)).toDF("f", "d", "i") - val result = df.select($"f", count("i").over(Window.partitionBy("f", "d"))) - checkAnswer(result, Seq( - Row(Float.NaN, 2), - Row(Float.NaN, 2), - Row(0.0f, 2), - Row(0.0f, 2))) + (Float.NaN, Double.NaN), + (0.0f/0.0f, 0.0/0.0), + (0.0f, 0.0), + (-0.0f, -0.0)).toDF("f", "d") + + checkAnswer( + df.select($"f", count(lit(1)).over(Window.partitionBy("f", "d"))), + Seq( + Row(Float.NaN, 2), + Row(0.0f/0.0f, 2), + Row(0.0f, 2), + Row(-0.0f, 2))) + + // test with complicated window partition keys. + val windowSpec1 = Window.partitionBy(array("f"), struct("d")) + checkAnswer( + df.select($"f", count(lit(1)).over(windowSpec1)), + Seq( + Row(Float.NaN, 2), + Row(0.0f/0.0f, 2), + Row(0.0f, 2), + Row(-0.0f, 2))) + + val windowSpec2 = Window.partitionBy(array(struct("f")), struct(array("d"))) + checkAnswer( + df.select($"f", count(lit(1)).over(windowSpec2)), + Seq( + Row(Float.NaN, 2), + Row(0.0f/0.0f, 2), + Row(0.0f, 2), + Row(-0.0f, 2))) + + // test with df with complicated-type columns. + val df2 = Seq( + (Array(-0.0f, 0.0f), Tuple2(-0.0d, Double.NaN), Seq(Tuple2(-0.0d, Double.NaN))), + (Array(0.0f, -0.0f), Tuple2(0.0d, Double.NaN), Seq(Tuple2(0.0d, 0.0/0.0))) + ).toDF("arr", "stru", "arrOfStru") + val windowSpec3 = Window.partitionBy("arr", "stru", "arrOfStru") + checkAnswer( + df2.select($"arr", $"stru", $"arrOfStru", count(lit(1)).over(windowSpec3)), + Seq( + Row(Seq(-0.0f, 0.0f), Row(-0.0d, Double.NaN), Seq(Row(-0.0d, Double.NaN)), 2), + Row(Seq(0.0f, -0.0f), Row(0.0d, Double.NaN), Seq(Row(0.0d, 0.0/0.0)), 2))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala index 0ded5d8ce1e28..4d7037f36b1fd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala @@ -397,27 +397,48 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { test("special floating point values") { import org.scalatest.exceptions.TestFailedException - // Spark treats -0.0 as 0.0 + // Spark distinguishes -0.0 and 0.0 intercept[TestFailedException] { - checkDataset(Seq(-0.0d).toDS(), -0.0d) + checkDataset(Seq(-0.0d).toDS(), 0.0d) } intercept[TestFailedException] { - checkDataset(Seq(-0.0f).toDS(), -0.0f) + checkAnswer(Seq(-0.0d).toDF(), Row(0.0d)) } intercept[TestFailedException] { - checkDataset(Seq(Tuple1(-0.0)).toDS(), Tuple1(-0.0)) + checkDataset(Seq(-0.0f).toDS(), 0.0f) } + intercept[TestFailedException] { + checkAnswer(Seq(-0.0f).toDF(), Row(0.0f)) + } + intercept[TestFailedException] { + checkDataset(Seq(Tuple1(-0.0)).toDS(), Tuple1(0.0)) + } + intercept[TestFailedException] { + checkAnswer(Seq(Tuple1(-0.0)).toDF(), Row(Row(0.0))) + } + intercept[TestFailedException] { + checkDataset(Seq(Seq(-0.0)).toDS(), Seq(0.0)) + } + intercept[TestFailedException] { + checkAnswer(Seq(Seq(-0.0)).toDF(), Row(Seq(0.0))) + } + + val floats = Seq[Float](-0.0f, 0.0f, Float.NaN) + checkDataset(floats.toDS(), floats: _*) + + val arrayOfFloats = Seq[Array[Float]](Array(0.0f, -0.0f), Array(-0.0f, Float.NaN)) + checkDataset(arrayOfFloats.toDS(), arrayOfFloats: _*) - val floats = Seq[Float](-0.0f, 0.0f, Float.NaN).toDS() - 
checkDataset(floats, 0.0f, 0.0f, Float.NaN) + val doubles = Seq[Double](-0.0d, 0.0d, Double.NaN) + checkDataset(doubles.toDS(), doubles: _*) - val doubles = Seq[Double](-0.0d, 0.0d, Double.NaN).toDS() - checkDataset(doubles, 0.0, 0.0, Double.NaN) + val arrayOfDoubles = Seq[Array[Double]](Array(0.0d, -0.0d), Array(-0.0d, Double.NaN)) + checkDataset(arrayOfDoubles.toDS(), arrayOfDoubles: _*) - checkDataset(Seq(Tuple1(Float.NaN)).toDS(), Tuple1(Float.NaN)) - checkDataset(Seq(Tuple1(-0.0f)).toDS(), Tuple1(0.0f)) - checkDataset(Seq(Tuple1(Double.NaN)).toDS(), Tuple1(Double.NaN)) - checkDataset(Seq(Tuple1(-0.0)).toDS(), Tuple1(0.0)) + val tuples = Seq[(Float, Float, Double, Double)]( + (0.0f, -0.0f, 0.0d, -0.0d), + (-0.0f, Float.NaN, -0.0d, Double.NaN)) + checkDataset(tuples.toDS(), tuples: _*) val complex = Map(Array(Seq(Tuple1(Double.NaN))) -> Map(Tuple2(Float.NaN, null))) checkDataset(Seq(complex).toDS(), complex) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 91445c8d96d85..81cc95847a79d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -909,4 +909,64 @@ class JoinSuite extends QueryTest with SharedSQLContext { checkAnswer(df, Row(1, 100, 42, 200, 1, 42)) } } + + test("NaN and -0.0 in join keys") { + withTempView("v1", "v2", "v3", "v4") { + Seq(Float.NaN -> Double.NaN, 0.0f -> 0.0, -0.0f -> -0.0).toDF("f", "d").createTempView("v1") + Seq(Float.NaN -> Double.NaN, 0.0f -> 0.0, -0.0f -> -0.0).toDF("f", "d").createTempView("v2") + + checkAnswer( + sql( + """ + |SELECT v1.f, v1.d, v2.f, v2.d + |FROM v1 JOIN v2 + |ON v1.f = v2.f AND v1.d = v2.d + """.stripMargin), + Seq( + Row(Float.NaN, Double.NaN, Float.NaN, Double.NaN), + Row(0.0f, 0.0, 0.0f, 0.0), + Row(0.0f, 0.0, -0.0f, -0.0), + Row(-0.0f, -0.0, 0.0f, 0.0), + Row(-0.0f, -0.0, -0.0f, -0.0))) + + // test with complicated join keys. + checkAnswer( + sql( + """ + |SELECT v1.f, v1.d, v2.f, v2.d + |FROM v1 JOIN v2 + |ON + | array(v1.f) = array(v2.f) AND + | struct(v1.d) = struct(v2.d) AND + | array(struct(v1.f, v1.d)) = array(struct(v2.f, v2.d)) AND + | struct(array(v1.f), array(v1.d)) = struct(array(v2.f), array(v2.d)) + """.stripMargin), + Seq( + Row(Float.NaN, Double.NaN, Float.NaN, Double.NaN), + Row(0.0f, 0.0, 0.0f, 0.0), + Row(0.0f, 0.0, -0.0f, -0.0), + Row(-0.0f, -0.0, 0.0f, 0.0), + Row(-0.0f, -0.0, -0.0f, -0.0))) + + // test with tables with complicated-type columns. 
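+      // Normalization only affects key comparison: the rows below still join even though the two
+      // sides store -0.0 vs 0.0 and different NaN bit patterns, and the output preserves each
+      // side's original values.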
+ Seq((Array(-0.0f, 0.0f), Tuple2(-0.0d, Double.NaN), Seq(Tuple2(-0.0d, Double.NaN)))) + .toDF("arr", "stru", "arrOfStru").createTempView("v3") + Seq((Array(0.0f, -0.0f), Tuple2(0.0d, 0.0/0.0), Seq(Tuple2(0.0d, 0.0/0.0)))) + .toDF("arr", "stru", "arrOfStru").createTempView("v4") + checkAnswer( + sql( + """ + |SELECT v3.arr, v3.stru, v3.arrOfStru, v4.arr, v4.stru, v4.arrOfStru + |FROM v3 JOIN v4 + |ON v3.arr = v4.arr AND v3.stru = v4.stru AND v3.arrOfStru = v4.arrOfStru + """.stripMargin), + Seq(Row( + Seq(-0.0f, 0.0f), + Row(-0.0d, Double.NaN), + Seq(Row(-0.0d, Double.NaN)), + Seq(0.0f, -0.0f), + Row(0.0d, 0.0/0.0), + Seq(Row(0.0d, 0.0/0.0))))) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index a547676c5ed5c..cf25f1ce910db 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -64,7 +64,7 @@ abstract class QueryTest extends PlanTest { expectedAnswer: T*): Unit = { val result = getResult(ds) - if (!compare(result.toSeq, expectedAnswer)) { + if (!QueryTest.compare(result.toSeq, expectedAnswer)) { fail( s""" |Decoded objects do not match expected objects: @@ -84,7 +84,7 @@ abstract class QueryTest extends PlanTest { expectedAnswer: T*): Unit = { val result = getResult(ds) - if (!compare(result.toSeq.sorted, expectedAnswer.sorted)) { + if (!QueryTest.compare(result.toSeq.sorted, expectedAnswer.sorted)) { fail( s""" |Decoded objects do not match expected objects: @@ -124,24 +124,6 @@ abstract class QueryTest extends PlanTest { } } - private def compare(obj1: Any, obj2: Any): Boolean = (obj1, obj2) match { - case (null, null) => true - case (null, _) => false - case (_, null) => false - case (a: Array[_], b: Array[_]) => - a.length == b.length && a.zip(b).forall { case (l, r) => compare(l, r)} - case (a: Iterable[_], b: Iterable[_]) => - a.size == b.size && a.zip(b).forall { case (l, r) => compare(l, r)} - case (a: Product, b: Product) => - compare(a.productIterator.toSeq, b.productIterator.toSeq) - // 0.0 == -0.0, turn float/double to binary before comparison, to distinguish 0.0 and -0.0. - case (a: Double, b: Double) => - java.lang.Double.doubleToRawLongBits(a) == java.lang.Double.doubleToRawLongBits(b) - case (a: Float, b: Float) => - java.lang.Float.floatToRawIntBits(a) == java.lang.Float.floatToRawIntBits(b) - case (a, b) => a == b - } - /** * Runs the plan and makes sure the answer matches the expected result. * @@ -310,9 +292,6 @@ object QueryTest { // Convert array to Seq for easy equality check. 
case b: Array[_] => b.toSeq case r: Row => prepareRow(r) - // spark treats -0.0 as 0.0 - case d: Double if d == -0.0d => 0.0d - case f: Float if f == -0.0f => 0.0f case o => o }) } @@ -352,11 +331,35 @@ object QueryTest { None } + private def compare(obj1: Any, obj2: Any): Boolean = (obj1, obj2) match { + case (null, null) => true + case (null, _) => false + case (_, null) => false + case (a: Array[_], b: Array[_]) => + a.length == b.length && a.zip(b).forall { case (l, r) => compare(l, r)} + case (a: Map[_, _], b: Map[_, _]) => + val entries1 = a.iterator.toSeq.sortBy(_.toString()) + val entries2 = b.iterator.toSeq.sortBy(_.toString()) + compare(entries1, entries2) + case (a: Iterable[_], b: Iterable[_]) => + a.size == b.size && a.zip(b).forall { case (l, r) => compare(l, r)} + case (a: Product, b: Product) => + compare(a.productIterator.toSeq, b.productIterator.toSeq) + case (a: Row, b: Row) => + compare(a.toSeq, b.toSeq) + // 0.0 == -0.0, turn float/double to bits before comparison, to distinguish 0.0 and -0.0. + case (a: Double, b: Double) => + java.lang.Double.doubleToRawLongBits(a) == java.lang.Double.doubleToRawLongBits(b) + case (a: Float, b: Float) => + java.lang.Float.floatToRawIntBits(a) == java.lang.Float.floatToRawIntBits(b) + case (a, b) => a == b + } + def sameRows( expectedAnswer: Seq[Row], sparkAnswer: Seq[Row], isSorted: Boolean = false): Option[String] = { - if (prepareAnswer(expectedAnswer, isSorted) != prepareAnswer(sparkAnswer, isSorted)) { + if (!compare(prepareAnswer(expectedAnswer, isSorted), prepareAnswer(sparkAnswer, isSorted))) { return Some(genError(expectedAnswer, sparkAnswer, isSorted)) } None diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index cfae2d82e273d..ecd428780c671 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -594,7 +594,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te | max(distinct value1) |FROM agg2 """.stripMargin), - Row(-60, 70.0, 101.0/9.0, 5.6, 100)) + Row(-60, 70, 101.0/9.0, 5.6, 100)) checkAnswer( spark.sql(