Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-26021][SQL] replace minus zero with zero in Platform.putDouble/Float #23043

Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors.attachTree
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral, JavaCode}
import org.apache.spark.sql.catalyst.expressions.codegen.{Block, CodegenContext, CodeGenerator, ExprCode, FalseLiteral, JavaCode}
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types._

Expand Down Expand Up @@ -56,17 +56,32 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
val javaType = JavaCode.javaType(dataType)
val value = CodeGenerator.getValue(ctx.INPUT_ROW, dataType, ordinal.toString)
if (nullable) {
ev.copy(code =
var codeBlock =
Copy link
Member

@kiszk kiszk Nov 15, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: better to use val instead of var.

code"""
|boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal);
|$javaType ${ev.value} = ${ev.isNull} ?
| ${CodeGenerator.defaultValue(dataType)} : ($value);
""".stripMargin)
""".stripMargin
codeBlock = codeBlock + genReplaceMinusZeroWithZeroCode(javaType.codeString, ev.value)
ev.copy(code = codeBlock)
} else {
ev.copy(code = code"$javaType ${ev.value} = $value;", isNull = FalseLiteral)
var codeBlock = code"$javaType ${ev.value} = $value;"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

codeBlock = codeBlock + genReplaceMinusZeroWithZeroCode(javaType.codeString, ev.value)
ev.copy(code = codeBlock, isNull = FalseLiteral)
}
}
}

/**
 * Builds a generated-Java snippet that overwrites `value` with positive zero when it
 * compares equal to minus zero.
 *
 * @param javaType Java type name of the generated variable; only "double",
 *                 "java.lang.Double", "float" and "java.lang.Float" produce any code.
 * @param value    name of the generated Java variable to test and reassign.
 * @return a [[Block]] wrapping `\nif (value == -0.0X) value = 0.0X;`, where X is the literal
 *         suffix `d` or `f`; for any other `javaType` the wrapped string is empty.
 */
private def genReplaceMinusZeroWithZeroCode(javaType: String, value: String): Block = {
// Template: each %c placeholder is later filled with the Java literal suffix ('d' or 'f').
// NOTE(review): in Java `0.0 == -0.0` is true, so the generated condition presumably also
// matches positive zero and simply reassigns it -- confirm this is intended.
val code = s"\nif ($value == -0.0%c) $value = 0.0%c;"
// Remains empty unless javaType is one of the floating-point types matched below.
var formattedCode = ""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

javaType match {
case "double" | "java.lang.Double" => formattedCode = code.format('d', 'd')
case "float" | "java.lang.Float" => formattedCode = code.format('f', 'f')
// Every other type: no replacement code is emitted.
case _ =>
}
code"$formattedCode"
}
}

object BindReferences extends Logging {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,22 @@ class GenerateUnsafeProjectionSuite extends SparkFunSuite {
assert(!result3.getStruct(0, 2).isNullAt(0))
assert(!result3.getStruct(0, 3).isNullAt(0))
}

test("SPARK-26021: Test replacing -0.0 with 0.0") {
  // One non-nullable and one nullable input column for each floating-point type.
  val boundRefs = Seq(
    BoundReference(0, DoubleType, nullable = false),
    BoundReference(1, DoubleType, nullable = true),
    BoundReference(2, FloatType, nullable = false),
    BoundReference(3, FloatType, nullable = true))
  // Every input slot carries minus zero, as a primitive or a boxed value.
  val minusZeros = InternalRow(-0.0d, Double.box(-0.0d), -0.0f, Float.box(-0.0f))
  val toUnsafe = GenerateUnsafeProjection.generate(boundRefs)
  val projected = toUnsafe(minusZeros)

  // `==` treats 0.0 and -0.0 as equal, so use compareTo, which tells the two zeros apart.
  Seq(0, 1).foreach { ordinal =>
    assert(projected.getDouble(ordinal).compareTo(0.0d) == 0)
  }
  Seq(2, 3).foreach { ordinal =>
    assert(projected.getFloat(ordinal).compareTo(0.0f) == 0)
  }
}
}

object AlwaysNull extends InternalRow {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -723,4 +723,32 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
"grouping expressions: [current_date(None)], value: [key: int, value: string], " +
"type: GroupBy]"))
}

test("SPARK-26021: Double and Float 0.0/-0.0 should be equal when grouping") {
val colName = "i"
def groupByCollect(df: DataFrame): Array[Row] = {
df.groupBy(colName).count().collect()
}
def assertResult[T](result: Array[Row], zero: T)(implicit ordering: Ordering[T]): Unit = {
assert(result.length == 1)
// using compare since 0.0 == -0.0 is true
assert(ordering.compare(result(0).getAs[T](0), zero) == 0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of checking the result, I prefer the code snippet in the JIRA ticket, which makes it more obvious where the problem is.

Let's run a group-by query with both 0.0 and -0.0 in the input, and then check the number of result rows: ideally 0.0 and -0.0 are the same, so we should have only one group (one result row).

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure I follow. Below this, I'm constructing Seqs with 0.0 and -0.0 as in the JIRA, and in the assertResult helper I'm checking that there's only 1 row, as you said.
Do you mean that the check that the key is indeed 0.0 and not -0.0 is redundant?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, sorry, I misread the code.

assert(result(0).getLong(1) == 3)
}

spark.conf.set("spark.sql.codegen.wholeStage", "false")
val doubles =
groupByCollect(Seq(0.0d, 0.0d, -0.0d).toDF(colName))
val doublesBoxed =
groupByCollect(Seq(Double.box(0.0d), Double.box(0.0d), Double.box(-0.0d)).toDF(colName))
val floats =
groupByCollect(Seq(0.0f, -0.0f, 0.0f).toDF(colName))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we have to turn off whole-stage codegen?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like a leftover from a different solution. Also, there's no need to test the boxed version now that it's not in the codegen. I'll simplify the test.

val floatsBoxed =
groupByCollect(Seq(Float.box(0.0f), Float.box(-0.0f), Float.box(0.0f)).toDF(colName))

assertResult(doubles, 0.0d)
assertResult(doublesBoxed, 0.0d)
assertResult(floats, 0.0f)
assertResult(floatsBoxed, 0.0f)
}
}