Skip to content

Commit

Permalink
[SPARK-47681][FOLLOWUP] Fix schema_of_variant for float inputs
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?

The current `schema_of_variant` depends on `JsonInferSchema.compatibleType` to find the common type of two types. This function doesn't handle the case of `float x decimal` or `decimal x float` correctly. It doesn't produce the expected result, but consider the two types as incompatible and produces `variant`. This change doesn't affect the JSON schema inference because it never produces `float` beforehand.

### Why are the changes needed?

It is a bug fix and is required to process floats correctly.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

A new unit test that checks all type combinations.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #47058 from chenhao-db/fix_schema_of_variant_float.

Authored-by: Chenhao Li <chenhao.li@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information
chenhao-db authored and cloud-fan committed Jun 24, 2024
1 parent 31fa9d8 commit 4663b84
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,11 @@ object JsonInferSchema {
case (DoubleType, _: DecimalType) | (_: DecimalType, DoubleType) =>
DoubleType

// This branch is only used by `SchemaOfVariant.mergeSchema` because `JsonInferSchema` never
// produces `FloatType`.
case (FloatType, _: DecimalType) | (_: DecimalType, FloatType) =>
DoubleType

case (t1: DecimalType, t2: DecimalType) =>
val scale = math.max(t1.scale, t2.scale)
val range = math.max(t1.precision - t1.scale, t2.precision - t2.scale)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.variant

import java.time.{LocalDateTime, ZoneId, ZoneOffset}

import scala.collection.mutable
import scala.reflect.runtime.universe.TypeTag

import org.apache.spark.{SparkFunSuite, SparkRuntimeException}
Expand Down Expand Up @@ -860,4 +861,50 @@ class VariantExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
StructType.fromDDL("c ARRAY<STRING>,b MAP<STRING, STRING>,a STRUCT<i: INT>"))
check(struct, """{"a":{"i":0},"b":{"a":"123","b":"true","c":"f"},"c":["123","true","f"]}""")
}

test("schema_of_variant - schema merge") {
val nul = Literal(null, StringType)
val boolean = Literal.default(BooleanType)
val long = Literal.default(LongType)
val string = Literal.default(StringType)
val double = Literal.default(DoubleType)
val date = Literal.default(DateType)
val timestamp = Literal.default(TimestampType)
val timestampNtz = Literal.default(TimestampNTZType)
val float = Literal.default(FloatType)
val binary = Literal.default(BinaryType)
val decimal = Literal(Decimal("123.456"), DecimalType(6, 3))
val array1 = Literal(Array(0L))
val array2 = Literal(Array(0.0))
val struct1 = Literal.default(StructType.fromDDL("a string"))
val struct2 = Literal.default(StructType.fromDDL("a boolean, b bigint"))
val inputs = Seq(nul, boolean, long, string, double, date, timestamp, timestampNtz, float,
binary, decimal, array1, array2, struct1, struct2)

val results = mutable.HashMap.empty[(Literal, Literal), String]
for (i <- inputs) {
val inputType = if (i.value == null) "VOID" else i.dataType.sql
results.put((nul, i), inputType)
results.put((i, i), inputType)
}
results.put((long, double), "DOUBLE")
results.put((long, float), "FLOAT")
results.put((long, decimal), "DECIMAL(23,3)")
results.put((double, float), "DOUBLE")
results.put((double, decimal), "DOUBLE")
results.put((date, timestamp), "TIMESTAMP")
results.put((date, timestampNtz), "TIMESTAMP_NTZ")
results.put((timestamp, timestampNtz), "TIMESTAMP")
results.put((float, decimal), "DOUBLE")
results.put((array1, array2), "ARRAY<DOUBLE>")
results.put((struct1, struct2), "STRUCT<a: VARIANT, b: BIGINT>")

for (i1 <- inputs) {
for (i2 <- inputs) {
val expected = results.getOrElse((i1, i2), results.getOrElse((i2, i1), "VARIANT"))
val array = CreateArray(Seq(Cast(i1, VariantType), Cast(i2, VariantType)))
checkEvaluation(SchemaOfVariant(Cast(array, VariantType)).replacement, s"ARRAY<$expected>")
}
}
}
}

0 comments on commit 4663b84

Please sign in to comment.