diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala index 5f94af5ffe636..43738204c6704 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala @@ -56,10 +56,6 @@ import org.apache.spark.sql.types._ object NormalizeFloatingNumbers extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan match { - // A subquery will be rewritten into join later, and will go through this rule - // eventually. Here we skip subquery, as we only need to run this rule once. - case _: Subquery => plan - case _ => plan transform { case w: Window if w.partitionSpec.exists(p => needNormalize(p)) => // Although the `windowExpressions` may refer to `partitionSpec` expressions, we don't need diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index a23e5831f5887..093f2dbd1e426 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -3449,6 +3449,24 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark checkAnswer(sql("select CAST(-32768 as short) DIV CAST (-1 as short)"), Seq(Row(Short.MinValue.toLong * -1))) } + + test("normalize special floating numbers in subquery") { + withTempView("v1", "v2", "v3") { + Seq(-0.0).toDF("d").createTempView("v1") + Seq(0.0).toDF("d").createTempView("v2") + spark.range(2).createTempView("v3") + + // non-correlated subquery + checkAnswer(sql("SELECT (SELECT v1.d FROM v1 JOIN v2 ON v1.d = v2.d)"), Row(-0.0)) + // correlated subquery + checkAnswer( + sql( + """ + |SELECT id FROM v3 WHERE EXISTS + | (SELECT v1.d FROM v1 JOIN v2 ON v1.d = v2.d WHERE id > 0) + |""".stripMargin), Row(1)) + } + } } case class Foo(bar: Option[String])