From 0eb17811591974ddd5cb59f0e06e573fc4ba5ff0 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 27 Dec 2018 02:14:55 +0800 Subject: [PATCH 1/7] retain the difference between 0.0 and -0.0 --- .../expressions/codegen/UnsafeWriter.java | 35 ---- .../optimizer/NormalizeFloatingNumbers.scala | 179 ++++++++++++++++++ .../sql/catalyst/optimizer/Optimizer.scala | 4 +- .../sql/catalyst/planning/patterns.scala | 5 +- .../codegen/UnsafeRowWriterSuite.scala | 21 -- .../sql/execution/aggregate/AggUtils.scala | 16 +- .../spark/sql/DataFrameAggregateSuite.scala | 36 ++-- .../apache/spark/sql/DataFrameJoinSuite.scala | 12 -- .../sql/DataFrameWindowFunctionsSuite.scala | 41 +++- .../spark/sql/DatasetPrimitiveSuite.scala | 39 ++-- .../org/apache/spark/sql/JoinSuite.scala | 46 +++++ .../org/apache/spark/sql/QueryTest.scala | 20 +- 12 files changed, 334 insertions(+), 120 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java index 7553ab8cf7000..95263a0da95a8 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java @@ -198,46 +198,11 @@ protected final void writeLong(long offset, long value) { Platform.putLong(getBuffer(), offset, value); } - // We need to take care of NaN and -0.0 in several places: - // 1. When compare values, different NaNs should be treated as same, `-0.0` and `0.0` should be - // treated as same. - // 2. In GROUP BY, different NaNs should belong to the same group, -0.0 and 0.0 should belong - // to the same group. - // 3. As join keys, different NaNs should be treated as same, `-0.0` and `0.0` should be - // treated as same. - // 4. 
As window partition keys, different NaNs should be treated as same, `-0.0` and `0.0` - // should be treated as same. - // - // Case 1 is fine, as we handle NaN and -0.0 well during comparison. For complex types, we - // recursively compare the fields/elements, so it's also fine. - // - // Case 2, 3 and 4 are problematic, as they compare `UnsafeRow` binary directly, and different - // NaNs have different binary representation, and the same thing happens for -0.0 and 0.0. - // - // Here we normalize NaN and -0.0, so that `UnsafeProjection` will normalize them when writing - // float/double columns and nested fields to `UnsafeRow`. - // - // Note that, we must do this for all the `UnsafeProjection`s, not only the ones that extract - // join/grouping/window partition keys. `UnsafeProjection` copies unsafe data directly for complex - // types, so nested float/double may not be normalized. We need to make sure that all the unsafe - // data(`UnsafeRow`, `UnsafeArrayData`, `UnsafeMapData`) will have flat/double normalized during - // creation. protected final void writeFloat(long offset, float value) { - if (Float.isNaN(value)) { - value = Float.NaN; - } else if (value == -0.0f) { - value = 0.0f; - } Platform.putFloat(getBuffer(), offset, value); } - // See comments for `writeFloat`. 
protected final void writeDouble(long offset, double value) { - if (Double.isNaN(value)) { - value = Double.NaN; - } else if (value == -0.0d) { - value = 0.0d; - } Platform.putDouble(getBuffer(), offset, value); } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala new file mode 100644 index 0000000000000..bded7e7d60c15 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.expressions.{Alias, And, ArrayTransform, CreateArray, CreateMap, CreateNamedStruct, CreateNamedStructUnsafe, CreateStruct, EqualTo, ExpectsInputTypes, Expression, GetStructField, LambdaFunction, NamedLambdaVariable, UnaryExpression} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Subquery, Window} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.types._ + +/** + * We need to take care of special floating numbers (NaN and -0.0) in several places: + * 1. When comparing values, different NaNs should be treated as same, `-0.0` and `0.0` should be + * treated as same. + * 2. In GROUP BY, different NaNs should belong to the same group, -0.0 and 0.0 should belong + * to the same group. + * 3. In join keys, different NaNs should be treated as same, `-0.0` and `0.0` should be + * treated as same. + * 4. In window partition keys, different NaNs should be treated as same, `-0.0` and `0.0` + * should be treated as same. + * + * Case 1 is fine, as we handle NaN and -0.0 well during comparison. For complex types, we + * recursively compare the fields/elements, so it's also fine. + * + * Case 2, 3 and 4 are problematic, as they compare `UnsafeRow` binary directly, and different + * NaNs have different binary representation, and the same thing happens for -0.0 and 0.0. + * + * This rule normalizes NaN and -0.0 in Window partition keys, Join keys and Aggregate grouping + * expressions. + */ +object NormalizeFloatingNumbers extends Rule[LogicalPlan] { + + def apply(plan: LogicalPlan): LogicalPlan = plan match { + // A subquery will be rewritten into join later, and will go through this rule + // eventually. Here we skip subquery, as we only need to run this rule once. 
+ case _: Subquery => plan + + case _ => plan transform { + case w: Window if w.partitionSpec.exists(p => needNormalize(p.dataType)) => + w.copy(partitionSpec = w.partitionSpec.map(normalize)) + + case j @ ExtractEquiJoinKeys(_, leftKeys, rightKeys, condition, _, _) + if leftKeys.exists(k => needNormalize(k.dataType)) => + val newLeftJoinKeys = leftKeys.map(normalize) + val newRightJoinKeys = rightKeys.map(normalize) + val newConditions = newLeftJoinKeys.zip(newRightJoinKeys).map { + case (l, r) => EqualTo(l, r) + } ++ condition + j.copy(condition = Some(newConditions.reduce(And))) + + // TODO: ideally Aggregate should also be handled here, but its grouping expressions are + // mixed in its aggregate expressions. It's unreliable to change the grouping expressions + // here. For now we normalize grouping expressions in `AggUtils` during planning. + } + } + + private def needNormalize(dt: DataType): Boolean = dt match { + case FloatType | DoubleType => true + case StructType(fields) => fields.exists(f => needNormalize(f.dataType)) + case ArrayType(et, _) => needNormalize(et) + // We don't need to handle MapType here, as it's not comparable. 
+ case _ => false + } + + private[sql] def normalize(expr: Expression): Expression = expr match { + case _ if expr.dataType == FloatType || expr.dataType == DoubleType => + NormalizeNaNAndZero(expr) + + case CreateNamedStruct(children) => + CreateNamedStruct(children.map(normalize)) + + case CreateNamedStructUnsafe(children) => + CreateNamedStructUnsafe(children.map(normalize)) + + case CreateArray(children) => + CreateArray(children.map(normalize)) + + case CreateMap(children) => + CreateMap(children.map(normalize)) + + case a: Alias if needNormalize(a.dataType) => + a.withNewChildren(Seq(normalize(a.child))) + + case _ if expr.dataType.isInstanceOf[StructType] && needNormalize(expr.dataType) => + val fields = expr.dataType.asInstanceOf[StructType].fields.indices.map { i => + normalize(GetStructField(expr, i)) + } + CreateStruct(fields) + + case _ if expr.dataType.isInstanceOf[ArrayType] && needNormalize(expr.dataType) => + val ArrayType(et, containsNull) = expr.dataType + val lv = NamedLambdaVariable("arg", et, containsNull) + val function = normalize(lv) + ArrayTransform(expr, LambdaFunction(function, Seq(lv))) + + case _ => expr + } +} + +case class NormalizeNaNAndZero(child: Expression) extends UnaryExpression with ExpectsInputTypes { + + override def dataType: DataType = child.dataType + + override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(FloatType, DoubleType)) + + private lazy val normalizer: Any => Any = child.dataType match { + case FloatType => (input: Any) => { + val f = input.asInstanceOf[Float] + if (f.isNaN) { + Float.NaN + } else if (f == -0.0f) { + 0.0f + } else { + f + } + } + + case DoubleType => (input: Any) => { + val d = input.asInstanceOf[Double] + if (d.isNaN) { + Double.NaN + } else if (d == -0.0d) { + 0.0d + } else { + d + } + } + } + + override def nullSafeEval(input: Any): Any = { + normalizer(input) + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val codeToNormalize = 
child.dataType match { + case FloatType => (f: String) => { + s""" + |if (Float.isNaN($f)) { + | ${ev.value} = Float.NaN; + |} else if ($f == -0.0f) { + | ${ev.value} = 0.0f; + |} else { + | ${ev.value} = $f; + |} + """.stripMargin + } + + case DoubleType => (d: String) => { + s""" + |if (Double.isNaN($d)) { + | ${ev.value} = Double.NaN; + |} else if ($d == -0.0d) { + | ${ev.value} = 0.0d; + |} else { + | ${ev.value} = $d; + |} + """.stripMargin + } + } + + nullSafeCodeGen(ctx, ev, codeToNormalize) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 06f908281dd3c..a3740a99e31f0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -180,7 +180,9 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) CollapseProject, RemoveNoopOperators) :+ Batch("UpdateAttributeReferences", Once, - UpdateNullabilityInAttributeReferences) + UpdateNullabilityInAttributeReferences) :+ + // This batch must be executed after the `RewriteSubquery` batch, which creates joins. 
+ Batch("NormalizeFloatingNumbers", Once, NormalizeFloatingNumbers) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index dfc3b2d22129d..95be0a52cb2ed 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -105,8 +105,8 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper { (JoinType, Seq[Expression], Seq[Expression], Option[Expression], LogicalPlan, LogicalPlan, JoinHint) - def unapply(plan: LogicalPlan): Option[ReturnType] = plan match { - case join @ Join(left, right, joinType, condition, hint) => + def unapply(join: Join): Option[ReturnType] = join match { + case Join(left, right, joinType, condition, hint) => logDebug(s"Considering join on: $condition") // Find equi-join predicates that can be evaluated before the join, and thus can be used // as join keys. 
@@ -140,7 +140,6 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper { } else { None } - case _ => None } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala index 22e1fa6dfed4f..86b8fa54c0fd4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala @@ -49,25 +49,4 @@ class UnsafeRowWriterSuite extends SparkFunSuite { // The two rows should be the equal assert(res1 == res2) } - - test("SPARK-26021: normalize float/double NaN and -0.0") { - val unsafeRowWriter1 = new UnsafeRowWriter(4) - unsafeRowWriter1.resetRowWriter() - unsafeRowWriter1.write(0, Float.NaN) - unsafeRowWriter1.write(1, Double.NaN) - unsafeRowWriter1.write(2, 0.0f) - unsafeRowWriter1.write(3, 0.0) - val res1 = unsafeRowWriter1.getRow - - val unsafeRowWriter2 = new UnsafeRowWriter(4) - unsafeRowWriter2.resetRowWriter() - unsafeRowWriter2.write(0, 0.0f/0.0f) - unsafeRowWriter2.write(1, 0.0/0.0) - unsafeRowWriter2.write(2, -0.0f) - unsafeRowWriter2.write(3, -0.0) - val res2 = unsafeRowWriter2.getRow - - // The two rows should be the equal - assert(res1 == res2) - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala index 6be88c463dbd9..8b7556b0c6c5a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala @@ -19,9 +19,9 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ +import 
org.apache.spark.sql.catalyst.optimizer.NormalizeFloatingNumbers import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.streaming.{StateStoreRestoreExec, StateStoreSaveExec} -import org.apache.spark.sql.internal.SQLConf /** * Utility functions used by the query planner to convert our plan to new aggregation code path. @@ -35,12 +35,20 @@ object AggUtils { initialInputBufferOffset: Int = 0, resultExpressions: Seq[NamedExpression] = Nil, child: SparkPlan): SparkPlan = { + // Ideally this should be done in `NormalizeFloatingNumbers`, but we do it here because + // `groupingExpressions` is not extracted during logical phase. + val normalizedGroupingExpressions = groupingExpressions.map { e => + NormalizeFloatingNumbers.normalize(e) match { + case n: NamedExpression => n + case other => Alias(other, e.name)(exprId = e.exprId) + } + } val useHash = HashAggregateExec.supportsAggregate( aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)) if (useHash) { HashAggregateExec( requiredChildDistributionExpressions = requiredChildDistributionExpressions, - groupingExpressions = groupingExpressions, + groupingExpressions = normalizedGroupingExpressions, aggregateExpressions = aggregateExpressions, aggregateAttributes = aggregateAttributes, initialInputBufferOffset = initialInputBufferOffset, @@ -53,7 +61,7 @@ object AggUtils { if (objectHashEnabled && useObjectHash) { ObjectHashAggregateExec( requiredChildDistributionExpressions = requiredChildDistributionExpressions, - groupingExpressions = groupingExpressions, + groupingExpressions = normalizedGroupingExpressions, aggregateExpressions = aggregateExpressions, aggregateAttributes = aggregateAttributes, initialInputBufferOffset = initialInputBufferOffset, @@ -62,7 +70,7 @@ object AggUtils { } else { SortAggregateExec( requiredChildDistributionExpressions = requiredChildDistributionExpressions, - groupingExpressions = groupingExpressions, + groupingExpressions = 
normalizedGroupingExpressions, aggregateExpressions = aggregateExpressions, aggregateAttributes = aggregateAttributes, initialInputBufferOffset = initialInputBufferOffset, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index ff64edcd07f4b..b1b115e957fe5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -724,17 +724,29 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { "type: GroupBy]")) } - test("SPARK-26021: Double and Float 0.0/-0.0 should be equal when grouping") { - val colName = "i" - val doubles = Seq(0.0d, -0.0d, 0.0d).toDF(colName).groupBy(colName).count().collect() - val floats = Seq(0.0f, -0.0f, 0.0f).toDF(colName).groupBy(colName).count().collect() - - assert(doubles.length == 1) - assert(floats.length == 1) - // using compare since 0.0 == -0.0 is true - assert(java.lang.Double.compare(doubles(0).getDouble(0), 0.0d) == 0) - assert(java.lang.Float.compare(floats(0).getFloat(0), 0.0f) == 0) - assert(doubles(0).getLong(1) == 3) - assert(floats(0).getLong(1) == 3) + test("SPARK-26021: NaN and -0.0 in grouping expressions") { + checkAnswer( + Seq(0.0f, -0.0f, 0.0f/0.0f, Float.NaN).toDF("f").groupBy("f").count(), + Row(0.0f, 2) :: Row(Float.NaN, 2) :: Nil) + checkAnswer( + Seq(0.0d, -0.0d, 0.0d/0.0d, Double.NaN).toDF("d").groupBy("d").count(), + Row(0.0d, 2) :: Row(Double.NaN, 2) :: Nil) + + // test with complicated type grouping expressions + checkAnswer( + Seq(0.0f, -0.0f, 0.0f/0.0f, Float.NaN).toDF("f").groupBy(array("f"), struct("f")).count(), + Row(Seq(0.0f), Row(0.0f), 2) :: Row(Seq(Float.NaN), Row(Float.NaN), 2) :: Nil) + checkAnswer( + Seq(0.0d, -0.0d, 0.0d/0.0d, Double.NaN).toDF("d").groupBy(array("d"), struct("d")).count(), + Row(Seq(0.0d), Row(0.0d), 2) :: Row(Seq(Double.NaN), 
Row(Double.NaN), 2) :: Nil) + + // test with complicated type grouping columns + val df = Seq( + Array(-0.0f, 0.0f) -> Tuple2(-0.0d, Double.NaN), + Array(0.0f, -0.0f) -> Tuple2(0.0d, Double.NaN)).toDF("arr", "stru") + checkAnswer( + df.groupBy("arr", "stru").count(), + Row(Seq(0.0f, 0.0f), Row(0.0d, Double.NaN), 2) + ) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index a4a3e2a62d1a5..6bd12cbf0135d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -295,16 +295,4 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { df.join(df, df("id") <=> df("id")).queryExecution.optimizedPlan } } - - test("NaN and -0.0 in join keys") { - val df1 = Seq(Float.NaN -> Double.NaN, 0.0f -> 0.0, -0.0f -> -0.0).toDF("f", "d") - val df2 = Seq(Float.NaN -> Double.NaN, 0.0f -> 0.0, -0.0f -> -0.0).toDF("f", "d") - val joined = df1.join(df2, Seq("f", "d")) - checkAnswer(joined, Seq( - Row(Float.NaN, Double.NaN), - Row(0.0f, 0.0), - Row(0.0f, 0.0), - Row(0.0f, 0.0), - Row(0.0f, 0.0))) - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala index 9277dc6859247..6c36163489c24 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala @@ -698,15 +698,36 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext { test("NaN and -0.0 in window partition keys") { val df = Seq( - (Float.NaN, Double.NaN, 1), - (0.0f/0.0f, 0.0/0.0, 1), - (0.0f, 0.0, 1), - (-0.0f, -0.0, 1)).toDF("f", "d", "i") - val result = df.select($"f", count("i").over(Window.partitionBy("f", "d"))) - checkAnswer(result, Seq( - 
Row(Float.NaN, 2), - Row(Float.NaN, 2), - Row(0.0f, 2), - Row(0.0f, 2))) + (Float.NaN, Double.NaN), + (0.0f/0.0f, 0.0/0.0), + (0.0f, 0.0), + (-0.0f, -0.0)).toDF("f", "d") + + checkAnswer( + df.select($"f", count(lit(1)).over(Window.partitionBy("f", "d"))), + Seq( + Row(Float.NaN, 2), + Row(0.0f/0.0f, 2), + Row(0.0f, 2), + Row(-0.0f, 2))) + + // test with complicated window partition keys. + checkAnswer( + df.select($"f", count(lit(1)).over(Window.partitionBy(array("f"), struct("d")))), + Seq( + Row(Float.NaN, 2), + Row(0.0f/0.0f, 2), + Row(0.0f, 2), + Row(-0.0f, 2))) + + // test with df with complicated-type columns. + val df2 = Seq( + Array(-0.0f, 0.0f) -> Tuple2(-0.0d, Double.NaN), + Array(0.0f, -0.0f) -> Tuple2(0.0d, Double.NaN)).toDF("arr", "stru") + checkAnswer( + df2.select($"arr", $"stru", count(lit(1)).over(Window.partitionBy("arr", "stru"))), + Seq( + Row(Seq(-0.0f, 0.0f), Row(-0.0d, Double.NaN), 2), + Row(Seq(0.0f, -0.0f), Row(0.0d, Double.NaN), 2))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala index 0ded5d8ce1e28..30f4f048295cd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala @@ -397,27 +397,42 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { test("special floating point values") { import org.scalatest.exceptions.TestFailedException - // Spark treats -0.0 as 0.0 + // Spark distinguishes -0.0 and 0.0 intercept[TestFailedException] { - checkDataset(Seq(-0.0d).toDS(), -0.0d) + checkDataset(Seq(-0.0d).toDS(), 0.0d) } intercept[TestFailedException] { - checkDataset(Seq(-0.0f).toDS(), -0.0f) + checkAnswer(Seq(-0.0d).toDF(), Row(0.0d)) } intercept[TestFailedException] { - checkDataset(Seq(Tuple1(-0.0)).toDS(), Tuple1(-0.0)) + checkDataset(Seq(-0.0f).toDS(), 0.0f) } + intercept[TestFailedException] { + 
checkAnswer(Seq(-0.0f).toDF(), Row(0.0f)) + } + intercept[TestFailedException] { + checkDataset(Seq(Tuple1(-0.0)).toDS(), Tuple1(0.0)) + } + intercept[TestFailedException] { + checkAnswer(Seq(Tuple1(-0.0)).toDF(), Row(Row(0.0))) + } + + val floats = Seq[Float](-0.0f, 0.0f, Float.NaN) + checkDataset(floats.toDS(), floats: _*) + + val arrayOfFloats = Seq[Array[Float]](Array(0.0f, -0.0f), Array(-0.0f, Float.NaN)) + checkDataset(arrayOfFloats.toDS(), arrayOfFloats: _*) - val floats = Seq[Float](-0.0f, 0.0f, Float.NaN).toDS() - checkDataset(floats, 0.0f, 0.0f, Float.NaN) + val doubles = Seq[Double](-0.0d, 0.0d, Double.NaN) + checkDataset(doubles.toDS(), doubles: _*) - val doubles = Seq[Double](-0.0d, 0.0d, Double.NaN).toDS() - checkDataset(doubles, 0.0, 0.0, Double.NaN) + val arrayOfDoubles = Seq[Array[Double]](Array(0.0d, -0.0d), Array(-0.0d, Double.NaN)) + checkDataset(arrayOfDoubles.toDS(), arrayOfDoubles: _*) - checkDataset(Seq(Tuple1(Float.NaN)).toDS(), Tuple1(Float.NaN)) - checkDataset(Seq(Tuple1(-0.0f)).toDS(), Tuple1(0.0f)) - checkDataset(Seq(Tuple1(Double.NaN)).toDS(), Tuple1(Double.NaN)) - checkDataset(Seq(Tuple1(-0.0)).toDS(), Tuple1(0.0)) + val tuples = Seq[(Float, Float, Double, Double)]( + (0.0f, -0.0f, 0.0d, -0.0d), + (-0.0f, Float.NaN, -0.0d, Double.NaN)) + checkDataset(tuples.toDS(), tuples: _*) val complex = Map(Array(Seq(Tuple1(Double.NaN))) -> Map(Tuple2(Float.NaN, null))) checkDataset(Seq(complex).toDS(), complex) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 91445c8d96d85..887d9adbce10a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -909,4 +909,50 @@ class JoinSuite extends QueryTest with SharedSQLContext { checkAnswer(df, Row(1, 100, 42, 200, 1, 42)) } } + + test("NaN and -0.0 in join keys") { + withTempView("v1", "v2", "v3", "v4") { + Seq(Float.NaN -> 
Double.NaN, 0.0f -> 0.0, -0.0f -> -0.0).toDF("f", "d").createTempView("v1") + Seq(Float.NaN -> Double.NaN, 0.0f -> 0.0, -0.0f -> -0.0).toDF("f", "d").createTempView("v2") + + checkAnswer( + sql( + """ + |SELECT v1.f, v1.d + |FROM v1 JOIN v2 + |ON v1.f = v2.f AND v1.d = v2.d + """.stripMargin), + Seq( + Row(Float.NaN, Double.NaN), + Row(0.0f, 0.0), + Row(0.0f, 0.0), + Row(-0.0f, -0.0), + Row(-0.0f, -0.0))) + + // test with complicated join keys. + checkAnswer( + sql( + """ + |SELECT v1.f, v1.d + |FROM v1 JOIN v2 + |ON array(v1.f) = array(v2.f) AND struct(v1.d) = struct(v2.d) + """.stripMargin), + Seq( + Row(Float.NaN, Double.NaN), + Row(0.0f, 0.0), + Row(0.0f, 0.0), + Row(-0.0f, -0.0), + Row(-0.0f, -0.0))) + + // test with tables with complicated-type columns. + Seq(Array(-0.0f, 0.0f) -> Tuple2(-0.0d, Double.NaN)).toDF("arr", "stru").createTempView("v3") + Seq(Array(0.0f, -0.0f) -> Tuple2(0.0d, Double.NaN)).toDF("arr", "stru").createTempView("v4") + checkAnswer( + sql("SELECT v3.arr, v3.stru FROM v3 JOIN v4 ON v3.arr = v4.arr AND v3.stru = v4.stru"), + Seq(Row(Seq(-0.0f, 0.0f), Row(-0.0d, Double.NaN)))) + checkAnswer( + sql("SELECT v4.arr, v4.stru FROM v3 JOIN v4 ON v3.arr = v4.arr AND v3.stru = v4.stru"), + Seq(Row(Seq(0.0f, -0.0f), Row(0.0d, Double.NaN)))) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index a547676c5ed5c..3df0612244ff0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -134,7 +134,7 @@ abstract class QueryTest extends PlanTest { a.size == b.size && a.zip(b).forall { case (l, r) => compare(l, r)} case (a: Product, b: Product) => compare(a.productIterator.toSeq, b.productIterator.toSeq) - // 0.0 == -0.0, turn float/double to binary before comparison, to distinguish 0.0 and -0.0. 
+ // 0.0 == -0.0, turn float/double to bits before comparison, to distinguish 0.0 and -0.0. case (a: Double, b: Double) => java.lang.Double.doubleToRawLongBits(a) == java.lang.Double.doubleToRawLongBits(b) case (a: Float, b: Float) => @@ -282,18 +282,18 @@ object QueryTest { } - def prepareAnswer(answer: Seq[Row], isSorted: Boolean): Seq[Row] = { + def prepareAnswer(answer: Seq[Row], isSorted: Boolean, forPrint: Boolean = false): Seq[Row] = { // Converts data to types that we can do equality comparison using Scala collections. // For BigDecimal type, the Scala type has a better definition of equality test (similar to // Java's java.math.BigDecimal.compareTo). // For binary arrays, we convert it to Seq to avoid of calling java.util.Arrays.equals for // equality test. - val converted: Seq[Row] = answer.map(prepareRow) + val converted: Seq[Row] = answer.map(prepareRow(_, forPrint)) if (!isSorted) converted.sortBy(_.toString()) else converted } // We need to call prepareRow recursively to handle schemas with struct types. - def prepareRow(row: Row): Row = { + def prepareRow(row: Row, forPrint: Boolean): Row = { Row.fromSeq(row.toSeq.map { case null => null case bd: java.math.BigDecimal => BigDecimal(bd) @@ -309,10 +309,10 @@ object QueryTest { } // Convert array to Seq for easy equality check. case b: Array[_] => b.toSeq - case r: Row => prepareRow(r) - // spark treats -0.0 as 0.0 - case d: Double if d == -0.0d => 0.0d - case f: Float if f == -0.0f => 0.0f + case r: Row => prepareRow(r, forPrint) + // 0.0 == -0.0, turn float/double to bits before comparison, to distinguish 0.0 and -0.0. 
+ case f: Float if !forPrint => java.lang.Float.floatToRawIntBits(f) + case d: Double if !forPrint => java.lang.Double.doubleToLongBits(d) case o => o }) } @@ -335,10 +335,10 @@ object QueryTest { sideBySide( s"== Correct Answer - ${expectedAnswer.size} ==" +: getRowType(expectedAnswer.headOption) +: - prepareAnswer(expectedAnswer, isSorted).map(_.toString()), + prepareAnswer(expectedAnswer, isSorted, forPrint = true).map(_.toString()), s"== Spark Answer - ${sparkAnswer.size} ==" +: getRowType(sparkAnswer.headOption) +: - prepareAnswer(sparkAnswer, isSorted).map(_.toString())).mkString("\n") + prepareAnswer(sparkAnswer, isSorted, forPrint = true).map(_.toString())).mkString("\n") } """.stripMargin } From ee5a1f0a60ef022228382fa59447418c4b8b26d5 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 27 Dec 2018 12:09:03 +0800 Subject: [PATCH 2/7] fix tests --- .../expressions/UnsafeRowConverterSuite.scala | 16 ---------------- .../hive/execution/AggregationQuerySuite.scala | 2 +- .../spark/sql/hive/execution/HiveUDFSuite.scala | 10 ++++++---- .../spark/sql/hive/execution/SQLQuerySuite.scala | 2 +- 4 files changed, 8 insertions(+), 22 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala index ecb8047459b0c..69523fa81bc65 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala @@ -246,22 +246,6 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers with PlanTestB // assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11)) } - testBothCodegenAndInterpreted("NaN canonicalization") { - val factory = UnsafeProjection - val fieldTypes: Array[DataType] = Array(FloatType, DoubleType) - - val row1 = new 
SpecificInternalRow(fieldTypes) - row1.setFloat(0, java.lang.Float.intBitsToFloat(0x7f800001)) - row1.setDouble(1, java.lang.Double.longBitsToDouble(0x7ff0000000000001L)) - - val row2 = new SpecificInternalRow(fieldTypes) - row2.setFloat(0, java.lang.Float.intBitsToFloat(0x7fffffff)) - row2.setDouble(1, java.lang.Double.longBitsToDouble(0x7fffffffffffffffL)) - - val converter = factory.create(fieldTypes) - assert(converter.apply(row1).getBytes === converter.apply(row2).getBytes) - } - testBothCodegenAndInterpreted("basic conversion with struct type") { val factory = UnsafeProjection val fieldTypes: Array[DataType] = Array( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index cfae2d82e273d..ecd428780c671 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -594,7 +594,7 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te | max(distinct value1) |FROM agg2 """.stripMargin), - Row(-60, 70.0, 101.0/9.0, 5.6, 100)) + Row(-60, 70, 101.0/9.0, 5.6, 100)) checkAnswer( spark.sql( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala index a6fc744cc8b5a..a9b44d08c4dd5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala @@ -141,11 +141,13 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { } test("SPARK-2693 udaf aggregates test") { - checkAnswer(sql("SELECT percentile(key, 1) FROM src LIMIT 1"), - sql("SELECT max(key) FROM src").collect().toSeq) + checkAnswer( + sql("SELECT percentile(key, 1) 
FROM src LIMIT 1"), + sql("SELECT double(max(key)) FROM src")) - checkAnswer(sql("SELECT percentile(key, array(1, 1)) FROM src LIMIT 1"), - sql("SELECT array(max(key), max(key)) FROM src").collect().toSeq) + checkAnswer( + sql("SELECT percentile(key, array(1, 1)) FROM src LIMIT 1"), + sql("SELECT array(double(max(key)), double(max(key))) FROM src")) } test("SPARK-16228 Percentile needs explicit cast to double") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 70efad103d13e..361e3fa06211b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -1757,7 +1757,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { sql("insert into tbl values ('3', '2.3')") checkAnswer( sql("select (cast (99 as decimal(19,6)) + cast('3' as decimal)) * cast('2.3' as decimal)"), - Row(204.0) + Row(BigDecimal(204.0)) ) checkAnswer( sql("select (cast(99 as decimal(19,6)) + '3') *'2.3' from tbl"), From d3c59927b8195a10e8b70afc9720132c211d1113 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 27 Dec 2018 21:35:25 +0800 Subject: [PATCH 3/7] add back migration guide --- docs/sql-migration-guide-upgrade.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/sql-migration-guide-upgrade.md b/docs/sql-migration-guide-upgrade.md index 0fcdd420bcfe3..4e36fd45fb04e 100644 --- a/docs/sql-migration-guide-upgrade.md +++ b/docs/sql-migration-guide-upgrade.md @@ -25,7 +25,7 @@ displayTitle: Spark SQL Upgrading Guide - In Spark version 2.4 and earlier, `Dataset.groupByKey` results to a grouped dataset with key attribute wrongly named as "value", if the key is non-struct type, e.g. int, string, array, etc. This is counterintuitive and makes the schema of aggregation queries weird. 
For example, the schema of `ds.groupByKey(...).count()` is `(value, count)`. Since Spark 3.0, we name the grouping attribute to "key". The old behaviour is preserved under a newly added configuration `spark.sql.legacy.dataset.nameNonStructGroupingKeyAsValue` with a default value of `false`. - - In Spark version 2.4 and earlier, float/double -0.0 is semantically equal to 0.0, but users can still distinguish them via `Dataset.show`, `Dataset.collect` etc. Since Spark 3.0, float/double -0.0 is replaced by 0.0 internally, and users can't distinguish them any more. + - In Spark version 2.4 and earlier, float/double -0.0 is semantically equal to 0.0, but -0.0 and 0.0 are considered as different values when used in aggregate grouping keys, window partition keys and join keys. Since Spark 3.0, this bug is fixed. For example, `Seq(-0.0, 0.0).toDF("d").groupBy("d").count()` returns `[(0.0, 2)]` in Spark 3.0, and `[(0.0, 1), (-0.0, 1)]` in Spark 2.4 and earlier. - In Spark version 2.4 and earlier, users can create a map with duplicated keys via built-in functions like `CreateMap`, `StringToMap`, etc. The behavior of map with duplicated keys is undefined, e.g. map look up respects the duplicated key appears first, `Dataset.collect` only keeps the duplicated key appears last, `MapKeys` returns duplicated keys, etc. Since Spark 3.0, these built-in functions will remove duplicated map keys with last wins policy. Users may still read map values with duplicated keys from data sources which do not enforce it (e.g. Parquet), the behavior will be undefined. 
From 74934dad409b9296067981e528fd564c987e03dd Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 28 Dec 2018 11:02:00 +0800 Subject: [PATCH 4/7] fix test --- .../spark/sql/DatasetPrimitiveSuite.scala | 6 ++ .../org/apache/spark/sql/QueryTest.scala | 59 +++++++++---------- .../sql/hive/execution/HiveUDFSuite.scala | 10 ++-- .../sql/hive/execution/SQLQuerySuite.scala | 2 +- 4 files changed, 40 insertions(+), 37 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala index 30f4f048295cd..4d7037f36b1fd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala @@ -416,6 +416,12 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { intercept[TestFailedException] { checkAnswer(Seq(Tuple1(-0.0)).toDF(), Row(Row(0.0))) } + intercept[TestFailedException] { + checkDataset(Seq(Seq(-0.0)).toDS(), Seq(0.0)) + } + intercept[TestFailedException] { + checkAnswer(Seq(Seq(-0.0)).toDF(), Row(Seq(0.0))) + } val floats = Seq[Float](-0.0f, 0.0f, Float.NaN) checkDataset(floats.toDS(), floats: _*) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 3df0612244ff0..2a6fa64b3e7ea 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -64,7 +64,7 @@ abstract class QueryTest extends PlanTest { expectedAnswer: T*): Unit = { val result = getResult(ds) - if (!compare(result.toSeq, expectedAnswer)) { + if (!QueryTest.compare(result.toSeq, expectedAnswer)) { fail( s""" |Decoded objects do not match expected objects: @@ -84,7 +84,7 @@ abstract class QueryTest extends PlanTest { expectedAnswer: T*): Unit = { val result = getResult(ds) - if (!compare(result.toSeq.sorted, 
expectedAnswer.sorted)) { + if (!QueryTest.compare(result.toSeq.sorted, expectedAnswer.sorted)) { fail( s""" |Decoded objects do not match expected objects: @@ -124,24 +124,6 @@ abstract class QueryTest extends PlanTest { } } - private def compare(obj1: Any, obj2: Any): Boolean = (obj1, obj2) match { - case (null, null) => true - case (null, _) => false - case (_, null) => false - case (a: Array[_], b: Array[_]) => - a.length == b.length && a.zip(b).forall { case (l, r) => compare(l, r)} - case (a: Iterable[_], b: Iterable[_]) => - a.size == b.size && a.zip(b).forall { case (l, r) => compare(l, r)} - case (a: Product, b: Product) => - compare(a.productIterator.toSeq, b.productIterator.toSeq) - // 0.0 == -0.0, turn float/double to bits before comparison, to distinguish 0.0 and -0.0. - case (a: Double, b: Double) => - java.lang.Double.doubleToRawLongBits(a) == java.lang.Double.doubleToRawLongBits(b) - case (a: Float, b: Float) => - java.lang.Float.floatToRawIntBits(a) == java.lang.Float.floatToRawIntBits(b) - case (a, b) => a == b - } - /** * Runs the plan and makes sure the answer matches the expected result. * @@ -282,18 +264,18 @@ object QueryTest { } - def prepareAnswer(answer: Seq[Row], isSorted: Boolean, forPrint: Boolean = false): Seq[Row] = { + def prepareAnswer(answer: Seq[Row], isSorted: Boolean): Seq[Row] = { // Converts data to types that we can do equality comparison using Scala collections. // For BigDecimal type, the Scala type has a better definition of equality test (similar to // Java's java.math.BigDecimal.compareTo). // For binary arrays, we convert it to Seq to avoid of calling java.util.Arrays.equals for // equality test. - val converted: Seq[Row] = answer.map(prepareRow(_, forPrint)) + val converted: Seq[Row] = answer.map(prepareRow) if (!isSorted) converted.sortBy(_.toString()) else converted } // We need to call prepareRow recursively to handle schemas with struct types. 
- def prepareRow(row: Row, forPrint: Boolean): Row = { + def prepareRow(row: Row): Row = { Row.fromSeq(row.toSeq.map { case null => null case bd: java.math.BigDecimal => BigDecimal(bd) @@ -309,10 +291,7 @@ object QueryTest { } // Convert array to Seq for easy equality check. case b: Array[_] => b.toSeq - case r: Row => prepareRow(r, forPrint) - // 0.0 == -0.0, turn float/double to bits before comparison, to distinguish 0.0 and -0.0. - case f: Float if !forPrint => java.lang.Float.floatToRawIntBits(f) - case d: Double if !forPrint => java.lang.Double.doubleToLongBits(d) + case r: Row => prepareRow(r) case o => o }) } @@ -335,10 +314,10 @@ object QueryTest { sideBySide( s"== Correct Answer - ${expectedAnswer.size} ==" +: getRowType(expectedAnswer.headOption) +: - prepareAnswer(expectedAnswer, isSorted, forPrint = true).map(_.toString()), + prepareAnswer(expectedAnswer, isSorted).map(_.toString()), s"== Spark Answer - ${sparkAnswer.size} ==" +: getRowType(sparkAnswer.headOption) +: - prepareAnswer(sparkAnswer, isSorted, forPrint = true).map(_.toString())).mkString("\n") + prepareAnswer(sparkAnswer, isSorted).map(_.toString())).mkString("\n") } """.stripMargin } @@ -352,11 +331,31 @@ object QueryTest { None } + private def compare(obj1: Any, obj2: Any): Boolean = (obj1, obj2) match { + case (null, null) => true + case (null, _) => false + case (_, null) => false + case (a: Array[_], b: Array[_]) => + a.length == b.length && a.zip(b).forall { case (l, r) => compare(l, r)} + case (a: Iterable[_], b: Iterable[_]) => + a.size == b.size && a.zip(b).forall { case (l, r) => compare(l, r)} + case (a: Product, b: Product) => + compare(a.productIterator.toSeq, b.productIterator.toSeq) + case (a: Row, b: Row) => + compare(a.toSeq, b.toSeq) + // 0.0 == -0.0, turn float/double to bits before comparison, to distinguish 0.0 and -0.0. 
+ case (a: Double, b: Double) => + java.lang.Double.doubleToRawLongBits(a) == java.lang.Double.doubleToRawLongBits(b) + case (a: Float, b: Float) => + java.lang.Float.floatToRawIntBits(a) == java.lang.Float.floatToRawIntBits(b) + case (a, b) => a == b + } + def sameRows( expectedAnswer: Seq[Row], sparkAnswer: Seq[Row], isSorted: Boolean = false): Option[String] = { - if (prepareAnswer(expectedAnswer, isSorted) != prepareAnswer(sparkAnswer, isSorted)) { + if (!compare(prepareAnswer(expectedAnswer, isSorted), prepareAnswer(sparkAnswer, isSorted))) { return Some(genError(expectedAnswer, sparkAnswer, isSorted)) } None diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala index a9b44d08c4dd5..a6fc744cc8b5a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala @@ -141,13 +141,11 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { } test("SPARK-2693 udaf aggregates test") { - checkAnswer( - sql("SELECT percentile(key, 1) FROM src LIMIT 1"), - sql("SELECT double(max(key)) FROM src")) + checkAnswer(sql("SELECT percentile(key, 1) FROM src LIMIT 1"), + sql("SELECT max(key) FROM src").collect().toSeq) - checkAnswer( - sql("SELECT percentile(key, array(1, 1)) FROM src LIMIT 1"), - sql("SELECT array(double(max(key)), double(max(key))) FROM src")) + checkAnswer(sql("SELECT percentile(key, array(1, 1)) FROM src LIMIT 1"), + sql("SELECT array(max(key), max(key)) FROM src").collect().toSeq) } test("SPARK-16228 Percentile needs explicit cast to double") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 361e3fa06211b..70efad103d13e 100644 --- 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -1757,7 +1757,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { sql("insert into tbl values ('3', '2.3')") checkAnswer( sql("select (cast (99 as decimal(19,6)) + cast('3' as decimal)) * cast('2.3' as decimal)"), - Row(BigDecimal(204.0)) + Row(204.0) ) checkAnswer( sql("select (cast(99 as decimal(19,6)) + '3') *'2.3' from tbl"), From fdc998872bbdf64a1a6221bf5a7db0ec80e12d02 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sun, 30 Dec 2018 13:45:35 +0800 Subject: [PATCH 5/7] add comment --- .../sql/catalyst/optimizer/NormalizeFloatingNumbers.scala | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala index bded7e7d60c15..a9468e15cf98b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala @@ -43,6 +43,11 @@ import org.apache.spark.sql.types._ * * This rule normalizes NaN and -0.0 in Window partition keys, Join keys and Aggregate grouping * expressions. + * + * Note that, this rule should be an analyzer rule, as it must be applied to make the query result + * corrected. Currently it's executed as an optimizer rule, because the optimizer may create new + * joins(for subquery) and reorder joins(may change the join condition), and this rule needs to be + * executed at the end. 
*/ object NormalizeFloatingNumbers extends Rule[LogicalPlan] { From 8dafc6417e0e84d0cf1caaf3b48b9e5c4fa70e14 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 2 Jan 2019 10:43:16 +0800 Subject: [PATCH 6/7] update --- .../org/apache/spark/sql/catalyst/optimizer/Optimizer.scala | 3 ++- sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala | 4 ++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index a3740a99e31f0..d51dc6663d434 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -212,7 +212,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) PullupCorrelatedPredicates.ruleName :: RewriteCorrelatedScalarSubquery.ruleName :: RewritePredicateSubquery.ruleName :: - PullOutPythonUDFInJoinCondition.ruleName :: Nil + PullOutPythonUDFInJoinCondition.ruleName :: + NormalizeFloatingNumbers.ruleName :: Nil /** * Optimize all the subqueries inside expression.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 2a6fa64b3e7ea..cf25f1ce910db 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -337,6 +337,10 @@ object QueryTest { case (_, null) => false case (a: Array[_], b: Array[_]) => a.length == b.length && a.zip(b).forall { case (l, r) => compare(l, r)} + case (a: Map[_, _], b: Map[_, _]) => + val entries1 = a.iterator.toSeq.sortBy(_.toString()) + val entries2 = b.iterator.toSeq.sortBy(_.toString()) + compare(entries1, entries2) case (a: Iterable[_], b: Iterable[_]) => a.size == b.size && a.zip(b).forall { case (l, r) => compare(l, r)} case (a: Product, b: Product) => From 3e8c1713ed86ae45aac5c21a97cd9f44f2059e0d Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 8 Jan 2019 20:07:00 +0800 Subject: [PATCH 7/7] address comments from Bart --- .../optimizer/NormalizeFloatingNumbers.scala | 44 +++++++++------ .../spark/sql/DataFrameAggregateSuite.scala | 39 +++++++++++--- .../sql/DataFrameWindowFunctionsSuite.scala | 31 ++++++++--- .../org/apache/spark/sql/JoinSuite.scala | 54 ++++++++++++------- 4 files changed, 119 insertions(+), 49 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala index a9468e15cf98b..520f24aa22e4c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala @@ -28,26 +28,30 @@ import org.apache.spark.sql.types._ * We need to take care of special floating numbers (NaN and -0.0) in several places: * 1. 
When compare values, different NaNs should be treated as same, `-0.0` and `0.0` should be * treated as same. - * 2. In GROUP BY, different NaNs should belong to the same group, -0.0 and 0.0 should belong - * to the same group. + * 2. In aggregate grouping keys, different NaNs should belong to the same group, -0.0 and 0.0 + * should belong to the same group. * 3. In join keys, different NaNs should be treated as same, `-0.0` and `0.0` should be * treated as same. - * 4. In window partition keys, different NaNs should be treated as same, `-0.0` and `0.0` - * should be treated as same. + * 4. In window partition keys, different NaNs should belong to the same partition, -0.0 and 0.0 + * should belong to the same partition. * * Case 1 is fine, as we handle NaN and -0.0 well during comparison. For complex types, we * recursively compare the fields/elements, so it's also fine. * - * Case 2, 3 and 4 are problematic, as they compare `UnsafeRow` binary directly, and different - * NaNs have different binary representation, and the same thing happens for -0.0 and 0.0. + * Case 2, 3 and 4 are problematic, as Spark SQL turns grouping/join/window partition keys into + * binary `UnsafeRow` and compares the binary data directly. Different NaNs have different binary + * representation, and the same thing happens for -0.0 and 0.0. * - * This rule normalizes NaN and -0.0 in Window partition keys, Join keys and Aggregate grouping - * expressions. + * This rule normalizes NaN and -0.0 in window partition keys, join keys and aggregate grouping + * keys. * - * Note that, this rule should be an analyzer rule, as it must be applied to make the query result - * corrected. Currently it's executed as an optimizer rule, because the optimizer may create new - * joins(for subquery) and reorder joins(may change the join condition), and this rule needs to be - * executed at the end. + * Ideally we should do the normalization in the physical operators that compare the + * binary `UnsafeRow` directly.
We don't need this normalization if the Spark SQL execution engine + * is not optimized to run on binary data. This rule is created to simplify the implementation, so + * that we have a single place to do normalization, which is more maintainable. + * + * Note that, this rule must be executed at the end of optimizer, because the optimizer may create + * new joins (the subquery rewrite) and new join conditions (the join reorder). */ object NormalizeFloatingNumbers extends Rule[LogicalPlan] { @@ -58,10 +62,18 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] { case _ => plan transform { case w: Window if w.partitionSpec.exists(p => needNormalize(p.dataType)) => + // Although the `windowExpressions` may refer to `partitionSpec` expressions, we don't need + // to normalize the `windowExpressions`, as they are executed per input row and should take + // the input row as it is. w.copy(partitionSpec = w.partitionSpec.map(normalize)) - case j @ ExtractEquiJoinKeys(_, leftKeys, rightKeys, condition, _, _) - if leftKeys.exists(k => needNormalize(k.dataType)) => + // Only hash join and sort merge join need the normalization. Here we catch all Joins with + // join keys, assuming Joins with join keys are always planned as hash join or sort merge + // join. It's very unlikely that we will break this assumption in the near future. + case j @ ExtractEquiJoinKeys(_, leftKeys, rightKeys, condition, _, _, _) + // The analyzer guarantees left and right join keys are of the same data type. Here we + // only need to check join keys of one side.
+ if leftKeys.exists(k => needNormalize(k.dataType)) => val newLeftJoinKeys = leftKeys.map(normalize) val newRightJoinKeys = rightKeys.map(normalize) val newConditions = newLeftJoinKeys.zip(newRightJoinKeys).map { @@ -79,7 +91,9 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] { case FloatType | DoubleType => true case StructType(fields) => fields.exists(f => needNormalize(f.dataType)) case ArrayType(et, _) => needNormalize(et) - // We don't need to handle MapType here, as it's not comparable. + // Currently MapType is not comparable and analyzer should fail earlier if this case happens. + case _: MapType => + throw new IllegalStateException("grouping/join/window partition keys cannot be map type.") case _ => false } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index b1b115e957fe5..73259a0ed3b50 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -725,6 +725,13 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { } test("SPARK-26021: NaN and -0.0 in grouping expressions") { + import java.lang.Float.floatToRawIntBits + import java.lang.Double.doubleToRawLongBits + + // 0.0/0.0 and NaN are different values. 
+ assert(floatToRawIntBits(0.0f/0.0f) != floatToRawIntBits(Float.NaN)) + assert(doubleToRawLongBits(0.0/0.0) != doubleToRawLongBits(Double.NaN)) + checkAnswer( Seq(0.0f, -0.0f, 0.0f/0.0f, Float.NaN).toDF("f").groupBy("f").count(), Row(0.0f, 2) :: Row(Float.NaN, 2) :: Nil) @@ -734,19 +741,35 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { // test with complicated type grouping expressions checkAnswer( - Seq(0.0f, -0.0f, 0.0f/0.0f, Float.NaN).toDF("f").groupBy(array("f"), struct("f")).count(), - Row(Seq(0.0f), Row(0.0f), 2) :: Row(Seq(Float.NaN), Row(Float.NaN), 2) :: Nil) + Seq(0.0f, -0.0f, 0.0f/0.0f, Float.NaN).toDF("f") + .groupBy(array("f"), struct("f")).count(), + Row(Seq(0.0f), Row(0.0f), 2) :: + Row(Seq(Float.NaN), Row(Float.NaN), 2) :: Nil) + checkAnswer( + Seq(0.0d, -0.0d, 0.0d/0.0d, Double.NaN).toDF("d") + .groupBy(array("d"), struct("d")).count(), + Row(Seq(0.0d), Row(0.0d), 2) :: + Row(Seq(Double.NaN), Row(Double.NaN), 2) :: Nil) + + checkAnswer( + Seq(0.0f, -0.0f, 0.0f/0.0f, Float.NaN).toDF("f") + .groupBy(array(struct("f")), struct(array("f"))).count(), + Row(Seq(Row(0.0f)), Row(Seq(0.0f)), 2) :: + Row(Seq(Row(Float.NaN)), Row(Seq(Float.NaN)), 2) :: Nil) checkAnswer( - Seq(0.0d, -0.0d, 0.0d/0.0d, Double.NaN).toDF("d").groupBy(array("d"), struct("d")).count(), - Row(Seq(0.0d), Row(0.0d), 2) :: Row(Seq(Double.NaN), Row(Double.NaN), 2) :: Nil) + Seq(0.0d, -0.0d, 0.0d/0.0d, Double.NaN).toDF("d") + .groupBy(array(struct("d")), struct(array("d"))).count(), + Row(Seq(Row(0.0d)), Row(Seq(0.0d)), 2) :: + Row(Seq(Row(Double.NaN)), Row(Seq(Double.NaN)), 2) :: Nil) // test with complicated type grouping columns val df = Seq( - Array(-0.0f, 0.0f) -> Tuple2(-0.0d, Double.NaN), - Array(0.0f, -0.0f) -> Tuple2(0.0d, Double.NaN)).toDF("arr", "stru") + (Array(-0.0f, 0.0f), Tuple2(-0.0d, Double.NaN), Seq(Tuple2(-0.0d, Double.NaN))), + (Array(0.0f, -0.0f), Tuple2(0.0d, Double.NaN), Seq(Tuple2(0.0d, 0.0/0.0))) + ).toDF("arr", "stru", "arrOfStru") 
checkAnswer( - df.groupBy("arr", "stru").count(), - Row(Seq(0.0f, 0.0f), Row(0.0d, Double.NaN), 2) + df.groupBy("arr", "stru", "arrOfStru").count(), + Row(Seq(0.0f, 0.0f), Row(0.0d, Double.NaN), Seq(Row(0.0d, Double.NaN)), 2) ) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala index 6c36163489c24..f4ba2f0673c0a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala @@ -697,6 +697,13 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext { } test("NaN and -0.0 in window partition keys") { + import java.lang.Float.floatToRawIntBits + import java.lang.Double.doubleToRawLongBits + + // 0.0/0.0 and NaN are different values. + assert(floatToRawIntBits(0.0f/0.0f) != floatToRawIntBits(Float.NaN)) + assert(doubleToRawLongBits(0.0/0.0) != doubleToRawLongBits(Double.NaN)) + val df = Seq( (Float.NaN, Double.NaN), (0.0f/0.0f, 0.0/0.0), @@ -712,8 +719,18 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext { Row(-0.0f, 2))) // test with complicated window partition keys. + val windowSpec1 = Window.partitionBy(array("f"), struct("d")) + checkAnswer( + df.select($"f", count(lit(1)).over(windowSpec1)), + Seq( + Row(Float.NaN, 2), + Row(0.0f/0.0f, 2), + Row(0.0f, 2), + Row(-0.0f, 2))) + + val windowSpec2 = Window.partitionBy(array(struct("f")), struct(array("d"))) checkAnswer( - df.select($"f", count(lit(1)).over(Window.partitionBy(array("f"), struct("d")))), + df.select($"f", count(lit(1)).over(windowSpec2)), Seq( Row(Float.NaN, 2), Row(0.0f/0.0f, 2), @@ -722,12 +739,14 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext { // test with df with complicated-type columns. 
val df2 = Seq( - Array(-0.0f, 0.0f) -> Tuple2(-0.0d, Double.NaN), - Array(0.0f, -0.0f) -> Tuple2(0.0d, Double.NaN)).toDF("arr", "stru") + (Array(-0.0f, 0.0f), Tuple2(-0.0d, Double.NaN), Seq(Tuple2(-0.0d, Double.NaN))), + (Array(0.0f, -0.0f), Tuple2(0.0d, Double.NaN), Seq(Tuple2(0.0d, 0.0/0.0))) + ).toDF("arr", "stru", "arrOfStru") + val windowSpec3 = Window.partitionBy("arr", "stru", "arrOfStru") checkAnswer( - df2.select($"arr", $"stru", count(lit(1)).over(Window.partitionBy("arr", "stru"))), + df2.select($"arr", $"stru", $"arrOfStru", count(lit(1)).over(windowSpec3)), Seq( - Row(Seq(-0.0f, 0.0f), Row(-0.0d, Double.NaN), 2), - Row(Seq(0.0f, -0.0f), Row(0.0d, Double.NaN), 2))) + Row(Seq(-0.0f, 0.0f), Row(-0.0d, Double.NaN), Seq(Row(-0.0d, Double.NaN)), 2), + Row(Seq(0.0f, -0.0f), Row(0.0d, Double.NaN), Seq(Row(0.0d, 0.0/0.0)), 2))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 887d9adbce10a..81cc95847a79d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -918,41 +918,55 @@ class JoinSuite extends QueryTest with SharedSQLContext { checkAnswer( sql( """ - |SELECT v1.f, v1.d + |SELECT v1.f, v1.d, v2.f, v2.d |FROM v1 JOIN v2 |ON v1.f = v2.f AND v1.d = v2.d """.stripMargin), Seq( - Row(Float.NaN, Double.NaN), - Row(0.0f, 0.0), - Row(0.0f, 0.0), - Row(-0.0f, -0.0), - Row(-0.0f, -0.0))) + Row(Float.NaN, Double.NaN, Float.NaN, Double.NaN), + Row(0.0f, 0.0, 0.0f, 0.0), + Row(0.0f, 0.0, -0.0f, -0.0), + Row(-0.0f, -0.0, 0.0f, 0.0), + Row(-0.0f, -0.0, -0.0f, -0.0))) // test with complicated join keys. 
checkAnswer( sql( """ - |SELECT v1.f, v1.d + |SELECT v1.f, v1.d, v2.f, v2.d |FROM v1 JOIN v2 - |ON array(v1.f) = array(v2.f) AND struct(v1.d) = struct(v2.d) + |ON + | array(v1.f) = array(v2.f) AND + | struct(v1.d) = struct(v2.d) AND + | array(struct(v1.f, v1.d)) = array(struct(v2.f, v2.d)) AND + | struct(array(v1.f), array(v1.d)) = struct(array(v2.f), array(v2.d)) """.stripMargin), Seq( - Row(Float.NaN, Double.NaN), - Row(0.0f, 0.0), - Row(0.0f, 0.0), - Row(-0.0f, -0.0), - Row(-0.0f, -0.0))) + Row(Float.NaN, Double.NaN, Float.NaN, Double.NaN), + Row(0.0f, 0.0, 0.0f, 0.0), + Row(0.0f, 0.0, -0.0f, -0.0), + Row(-0.0f, -0.0, 0.0f, 0.0), + Row(-0.0f, -0.0, -0.0f, -0.0))) // test with tables with complicated-type columns. - Seq(Array(-0.0f, 0.0f) -> Tuple2(-0.0d, Double.NaN)).toDF("arr", "stru").createTempView("v3") - Seq(Array(0.0f, -0.0f) -> Tuple2(0.0d, Double.NaN)).toDF("arr", "stru").createTempView("v4") + Seq((Array(-0.0f, 0.0f), Tuple2(-0.0d, Double.NaN), Seq(Tuple2(-0.0d, Double.NaN)))) + .toDF("arr", "stru", "arrOfStru").createTempView("v3") + Seq((Array(0.0f, -0.0f), Tuple2(0.0d, 0.0/0.0), Seq(Tuple2(0.0d, 0.0/0.0)))) + .toDF("arr", "stru", "arrOfStru").createTempView("v4") checkAnswer( - sql("SELECT v3.arr, v3.stru FROM v3 JOIN v4 ON v3.arr = v4.arr AND v3.stru = v4.stru"), - Seq(Row(Seq(-0.0f, 0.0f), Row(-0.0d, Double.NaN)))) - checkAnswer( - sql("SELECT v4.arr, v4.stru FROM v3 JOIN v4 ON v3.arr = v4.arr AND v3.stru = v4.stru"), - Seq(Row(Seq(0.0f, -0.0f), Row(0.0d, Double.NaN)))) + sql( + """ + |SELECT v3.arr, v3.stru, v3.arrOfStru, v4.arr, v4.stru, v4.arrOfStru + |FROM v3 JOIN v4 + |ON v3.arr = v4.arr AND v3.stru = v4.stru AND v3.arrOfStru = v4.arrOfStru + """.stripMargin), + Seq(Row( + Seq(-0.0f, 0.0f), + Row(-0.0d, Double.NaN), + Seq(Row(-0.0d, Double.NaN)), + Seq(0.0f, -0.0f), + Row(0.0d, 0.0/0.0), + Seq(Row(0.0d, 0.0/0.0))))) } } }