From 4663b84eed1954b5df38e522e14eb2bee823f294 Mon Sep 17 00:00:00 2001 From: Chenhao Li Date: Mon, 24 Jun 2024 16:04:24 +0800 Subject: [PATCH] [SPARK-47681][FOLLOWUP] Fix schema_of_variant for float inputs ### What changes were proposed in this pull request? The current `schema_of_variant` depends on `JsonInferSchema.compatibleType` to find the common type of two types. This function doesn't handle the case of `float x decimal` or `decimal x float` correctly. It doesn't produce the expected result, but considers the two types as incompatible and produces `variant`. This change doesn't affect the JSON schema inference because it never produces `float` beforehand. ### Why are the changes needed? It is a bug fix and is required to process floats correctly. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? A new unit test that checks all type combinations. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #47058 from chenhao-db/fix_schema_of_variant_float. Authored-by: Chenhao Li Signed-off-by: Wenchen Fan --- .../sql/catalyst/json/JsonInferSchema.scala | 5 ++ .../variant/VariantExpressionSuite.scala | 47 +++++++++++++++++++ 2 files changed, 52 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala index 7ee522226e3ec..d982e1f19da0c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala @@ -372,6 +372,11 @@ object JsonInferSchema { case (DoubleType, _: DecimalType) | (_: DecimalType, DoubleType) => DoubleType + // This branch is only used by `SchemaOfVariant.mergeSchema` because `JsonInferSchema` never + // produces `FloatType`. 
+ case (FloatType, _: DecimalType) | (_: DecimalType, FloatType) => + DoubleType + case (t1: DecimalType, t2: DecimalType) => val scale = math.max(t1.scale, t2.scale) val range = math.max(t1.precision - t1.scale, t2.precision - t2.scale) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala index 73abf8074e8c2..a758fa84f6fca 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.variant import java.time.{LocalDateTime, ZoneId, ZoneOffset} +import scala.collection.mutable import scala.reflect.runtime.universe.TypeTag import org.apache.spark.{SparkFunSuite, SparkRuntimeException} @@ -860,4 +861,50 @@ class VariantExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { StructType.fromDDL("c ARRAY<STRING>,b MAP<STRING, STRING>,a STRUCT<i: INT>")) check(struct, """{"a":{"i":0},"b":{"a":"123","b":"true","c":"f"},"c":["123","true","f"]}""") } + + test("schema_of_variant - schema merge") { + val nul = Literal(null, StringType) + val boolean = Literal.default(BooleanType) + val long = Literal.default(LongType) + val string = Literal.default(StringType) + val double = Literal.default(DoubleType) + val date = Literal.default(DateType) + val timestamp = Literal.default(TimestampType) + val timestampNtz = Literal.default(TimestampNTZType) + val float = Literal.default(FloatType) + val binary = Literal.default(BinaryType) + val decimal = Literal(Decimal("123.456"), DecimalType(6, 3)) + val array1 = Literal(Array(0L)) + val array2 = Literal(Array(0.0)) + val struct1 = Literal.default(StructType.fromDDL("a string")) + val struct2 = Literal.default(StructType.fromDDL("a boolean, b bigint")) 
+ val inputs = Seq(nul, boolean, long, string, double, date, timestamp, timestampNtz, float, + binary, decimal, array1, array2, struct1, struct2) + + val results = mutable.HashMap.empty[(Literal, Literal), String] + for (i <- inputs) { + val inputType = if (i.value == null) "VOID" else i.dataType.sql + results.put((nul, i), inputType) + results.put((i, i), inputType) + } + results.put((long, double), "DOUBLE") + results.put((long, float), "FLOAT") + results.put((long, decimal), "DECIMAL(23,3)") + results.put((double, float), "DOUBLE") + results.put((double, decimal), "DOUBLE") + results.put((date, timestamp), "TIMESTAMP") + results.put((date, timestampNtz), "TIMESTAMP_NTZ") + results.put((timestamp, timestampNtz), "TIMESTAMP") + results.put((float, decimal), "DOUBLE") + results.put((array1, array2), "ARRAY<DOUBLE>") + results.put((struct1, struct2), "STRUCT<a: VARIANT, b: BIGINT>") + + for (i1 <- inputs) { + for (i2 <- inputs) { + val expected = results.getOrElse((i1, i2), results.getOrElse((i2, i1), "VARIANT")) + val array = CreateArray(Seq(Cast(i1, VariantType), Cast(i2, VariantType))) + checkEvaluation(SchemaOfVariant(Cast(array, VariantType)).replacement, s"ARRAY<$expected>") + } + } + } }