[SPARK-47681][SQL] Add schema_of_variant expression
### What changes were proposed in this pull request?

This PR adds a new `SchemaOfVariant` expression. It returns the schema of a variant value in SQL format.

Usage examples:

```
> SELECT schema_of_variant(parse_json('null'));
 VOID
> SELECT schema_of_variant(parse_json('[{"b":true,"a":0}]'));
 ARRAY<STRUCT<a: BIGINT, b: BOOLEAN>>
```

### Why are the changes needed?

This expression can help the user explore the content of variant values.

### Does this PR introduce _any_ user-facing change?

Yes. A new SQL expression, `schema_of_variant`, is added.

### How was this patch tested?

Unit tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #45806 from chenhao-db/variant_schema.

Authored-by: Chenhao Li <chenhao.li@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
chenhao-db authored and cloud-fan committed Apr 8, 2024
1 parent 60806c6 commit 134a139
Showing 6 changed files with 134 additions and 10 deletions.
FunctionRegistry.scala

```diff
@@ -822,6 +822,7 @@ object FunctionRegistry {
     expression[ParseJson]("parse_json"),
     expressionBuilder("variant_get", VariantGetExpressionBuilder),
     expressionBuilder("try_variant_get", TryVariantGetExpressionBuilder),
+    expression[SchemaOfVariant]("schema_of_variant"),
 
     // cast
     expression[Cast]("cast"),
```
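
Once registered, the function resolves by name in SQL. As a quick sanity check, a hypothetical `spark-shell` session (not part of this commit) could verify the registration; the class and usage strings in the expected output are the ones added elsewhere in this diff:

```scala
// Hypothetical spark-shell check; the output shape follows DESCRIBE FUNCTION's usual format.
spark.sql("DESCRIBE FUNCTION schema_of_variant").show(truncate = false)
// Function: schema_of_variant
// Class: org.apache.spark.sql.catalyst.expressions.variant.SchemaOfVariant
// Usage: schema_of_variant(v) - Returns schema in the SQL format of a variant.
```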
EncoderUtils.scala

```diff
@@ -23,8 +23,8 @@ import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder, C
 import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.types.{PhysicalBinaryType, PhysicalIntegerType, PhysicalLongType}
 import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
-import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, CalendarIntervalType, DataType, DateType, DayTimeIntervalType, Decimal, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, ObjectType, ShortType, StringType, StructType, TimestampNTZType, TimestampType, UserDefinedType, YearMonthIntervalType}
-import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
+import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, CalendarIntervalType, DataType, DateType, DayTimeIntervalType, Decimal, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, ObjectType, ShortType, StringType, StructType, TimestampNTZType, TimestampType, UserDefinedType, VariantType, YearMonthIntervalType}
+import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String, VariantVal}
 
 /**
  * Helper class for Generating [[ExpressionEncoder]]s.
@@ -122,7 +122,8 @@ object EncoderUtils {
     TimestampType -> classOf[PhysicalLongType.InternalType],
     TimestampNTZType -> classOf[PhysicalLongType.InternalType],
     BinaryType -> classOf[PhysicalBinaryType.InternalType],
-    CalendarIntervalType -> classOf[CalendarInterval]
+    CalendarIntervalType -> classOf[CalendarInterval],
+    VariantType -> classOf[VariantVal]
   )
 
   val typeBoxedJavaMapping: Map[DataType, Class[_]] = Map[DataType, Class[_]](
```
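
This mapping tells the encoder framework which JVM class backs a `VariantType` column. A minimal sketch of the lookup, assuming the enclosing map is `typeJavaMapping` and that `EncoderUtils` lives in `org.apache.spark.sql.catalyst.encoders` (neither the map's name nor the package is visible in this hunk):

```scala
import org.apache.spark.sql.catalyst.encoders.EncoderUtils  // assumed package
import org.apache.spark.sql.types.VariantType

// Assumed map name: the entries shown in the hunk above.
val clazz = EncoderUtils.typeJavaMapping(VariantType)
// clazz == classOf[org.apache.spark.unsafe.types.VariantVal]
```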
variantExpressions.scala

```diff
@@ -19,13 +19,15 @@ package org.apache.spark.sql.catalyst.expressions.variant
 
 import scala.util.parsing.combinator.RegexParsers
 
+import org.apache.spark.SparkRuntimeException
 import org.apache.spark.sql.catalyst.analysis.ExpressionBuilder
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
+import org.apache.spark.sql.catalyst.json.JsonInferSchema
 import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, VARIANT_GET}
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
 import org.apache.spark.sql.catalyst.util.DateTimeConstants._
@@ -403,3 +405,87 @@ object VariantGetExpressionBuilder extends VariantGetExpressionBuilderBase(true)
 )
 // scalastyle:on line.size.limit
 object TryVariantGetExpressionBuilder extends VariantGetExpressionBuilderBase(false)
+
+@ExpressionDescription(
+  usage = "_FUNC_(v) - Returns schema in the SQL format of a variant.",
+  examples = """
+    Examples:
+      > SELECT _FUNC_(parse_json('null'));
+       VOID
+      > SELECT _FUNC_(parse_json('[{"b":true,"a":0}]'));
+       ARRAY<STRUCT<a: BIGINT, b: BOOLEAN>>
+  """,
+  since = "4.0.0",
+  group = "variant_funcs"
+)
+case class SchemaOfVariant(child: Expression)
+  extends UnaryExpression
+  with RuntimeReplaceable
+  with ExpectsInputTypes {
+  override lazy val replacement: Expression = StaticInvoke(
+    SchemaOfVariant.getClass,
+    StringType,
+    "schemaOfVariant",
+    Seq(child),
+    inputTypes,
+    returnNullable = false)
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(VariantType)
+
+  override def dataType: DataType = StringType
+
+  override def prettyName: String = "schema_of_variant"
+
+  override protected def withNewChildInternal(newChild: Expression): SchemaOfVariant =
+    copy(child = newChild)
+}
+
+object SchemaOfVariant {
+  /** The actual implementation of the `SchemaOfVariant` expression. */
+  def schemaOfVariant(input: VariantVal): UTF8String = {
+    val v = new Variant(input.getValue, input.getMetadata)
+    UTF8String.fromString(schemaOf(v).sql)
+  }
+
+  /**
+   * Return the schema of a variant. Struct fields are guaranteed to be sorted alphabetically.
+   */
+  def schemaOf(v: Variant): DataType = v.getType match {
+    case Type.OBJECT =>
+      val size = v.objectSize()
+      val fields = new Array[StructField](size)
+      for (i <- 0 until size) {
+        val field = v.getFieldAtIndex(i)
+        fields(i) = StructField(field.key, schemaOf(field.value))
+      }
+      // According to the variant spec, object fields must be sorted alphabetically. So we don't
+      // have to sort, but just need to validate they are sorted.
+      for (i <- 1 until size) {
+        if (fields(i - 1).name >= fields(i).name) {
+          throw new SparkRuntimeException("MALFORMED_VARIANT", Map.empty)
+        }
+      }
+      StructType(fields)
+    case Type.ARRAY =>
+      var elementType: DataType = NullType
+      for (i <- 0 until v.arraySize()) {
+        elementType = mergeSchema(elementType, schemaOf(v.getElementAtIndex(i)))
+      }
+      ArrayType(elementType)
+    case Type.NULL => NullType
+    case Type.BOOLEAN => BooleanType
+    case Type.LONG => LongType
+    case Type.STRING => StringType
+    case Type.DOUBLE => DoubleType
+    case Type.DECIMAL =>
+      val d = v.getDecimal
+      DecimalType(d.precision(), d.scale())
+  }
+
+  /**
+   * Returns the tightest common type for two given data types. Input struct fields are assumed to
+   * be sorted alphabetically.
+   */
+  def mergeSchema(t1: DataType, t2: DataType): DataType =
+    JsonInferSchema.compatibleType(t1, t2, VariantType)
+}
```
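
To make the merge rules concrete, here is a REPL-style sketch of `mergeSchema` (illustrative only, not part of the commit; the expected results line up with the `VariantEndToEndSuite` cases near the end of this diff):

```scala
import org.apache.spark.sql.catalyst.expressions.variant.SchemaOfVariant.mergeSchema
import org.apache.spark.sql.types._

// NULL merges into any type, which is why "[false]" infers ARRAY<BOOLEAN>.
mergeSchema(NullType, BooleanType)        // BooleanType

// BIGINT vs DECIMAL(2,1): the integral side widens to DECIMAL(20,0) first,
// then the common decimal keeps 20 integral digits plus one fractional digit.
mergeSchema(LongType, DecimalType(2, 1))  // DecimalType(21,1) -- cf. "[null, 1, 1.1]"

// Types with no common supertype fall back to VARIANT rather than STRING.
mergeSchema(LongType, StringType)         // VariantType -- cf. "[1, \"1\"]"
```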
JsonInferSchema.scala

```diff
@@ -360,8 +360,10 @@ object JsonInferSchema {
 
   /**
    * Returns the most general data type for two given data types.
+   * When the two types are incompatible, return `defaultDataType` as a fallback result.
    */
-  def compatibleType(t1: DataType, t2: DataType): DataType = {
+  def compatibleType(
+      t1: DataType, t2: DataType, defaultDataType: DataType = StringType): DataType = {
     TypeCoercion.findTightestCommonType(t1, t2).getOrElse {
       // t1 or t2 is a StructType, ArrayType, or an unexpected type.
       (t1, t2) match {
@@ -399,7 +401,8 @@
             val f2Name = fields2(f2Idx).name
             val comp = f1Name.compareTo(f2Name)
             if (comp == 0) {
-              val dataType = compatibleType(fields1(f1Idx).dataType, fields2(f2Idx).dataType)
+              val dataType = compatibleType(
+                fields1(f1Idx).dataType, fields2(f2Idx).dataType, defaultDataType)
               newFields.add(StructField(f1Name, dataType, nullable = true))
               f1Idx += 1
               f2Idx += 1
@@ -422,21 +425,22 @@
           StructType(newFields.toArray(emptyStructFieldArray))
 
         case (ArrayType(elementType1, containsNull1), ArrayType(elementType2, containsNull2)) =>
-          ArrayType(compatibleType(elementType1, elementType2), containsNull1 || containsNull2)
+          ArrayType(
+            compatibleType(elementType1, elementType2, defaultDataType),
+            containsNull1 || containsNull2)
 
         // The case that given `DecimalType` is capable of given `IntegralType` is handled in
         // `findTightestCommonType`. Both cases below will be executed only when the given
         // `DecimalType` is not capable of the given `IntegralType`.
         case (t1: IntegralType, t2: DecimalType) =>
-          compatibleType(DecimalType.forType(t1), t2)
+          compatibleType(DecimalType.forType(t1), t2, defaultDataType)
         case (t1: DecimalType, t2: IntegralType) =>
-          compatibleType(t1, DecimalType.forType(t2))
+          compatibleType(t1, DecimalType.forType(t2), defaultDataType)
 
         case (TimestampNTZType, TimestampType) | (TimestampType, TimestampNTZType) =>
           TimestampType
 
         // strings and every string is a Json object.
-        case (_, _) => StringType
+        case (_, _) => defaultDataType
       }
     }
   }
```
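
A short sketch of what the new `defaultDataType` parameter changes (illustrative only, using the signature from the hunk above): JSON schema inference keeps its historical `StringType` fallback, while the variant code path passes `VariantType` so that conflicting fields surface as VARIANT:

```scala
import org.apache.spark.sql.catalyst.json.JsonInferSchema.compatibleType
import org.apache.spark.sql.types._

val s1 = StructType(Seq(StructField("a", LongType)))
val s2 = StructType(Seq(StructField("a", BooleanType)))

// Default fallback (JSON inference) is unchanged: conflicts widen to STRING.
compatibleType(s1, s2)               // STRUCT<a: STRING>

// The variant caller supplies VariantType, so the same conflict becomes VARIANT.
compatibleType(s1, s2, VariantType)  // STRUCT<a: VARIANT>
```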
sql-expression-schema.md

```diff
@@ -437,6 +437,7 @@
 | org.apache.spark.sql.catalyst.expressions.aggregate.VarianceSamp | var_samp | SELECT var_samp(col) FROM VALUES (1), (2), (3) AS tab(col) | struct<var_samp(col):double> |
 | org.apache.spark.sql.catalyst.expressions.aggregate.VarianceSamp | variance | SELECT variance(col) FROM VALUES (1), (2), (3) AS tab(col) | struct<variance(col):double> |
 | org.apache.spark.sql.catalyst.expressions.variant.ParseJson | parse_json | SELECT parse_json('{"a":1,"b":0.8}') | struct<parse_json({"a":1,"b":0.8}):variant> |
+| org.apache.spark.sql.catalyst.expressions.variant.SchemaOfVariant | schema_of_variant | SELECT schema_of_variant(parse_json('null')) | struct<schema_of_variant(parse_json(null)):string> |
 | org.apache.spark.sql.catalyst.expressions.variant.TryVariantGetExpressionBuilder | try_variant_get | SELECT try_variant_get(parse_json('{"a": 1}'), '$.a', 'int') | struct<try_variant_get(parse_json({"a": 1}), $.a):int> |
 | org.apache.spark.sql.catalyst.expressions.variant.VariantGetExpressionBuilder | variant_get | SELECT variant_get(parse_json('{"a": 1}'), '$.a', 'int') | struct<variant_get(parse_json({"a": 1}), $.a):int> |
 | org.apache.spark.sql.catalyst.expressions.xml.XPathBoolean | xpath_boolean | SELECT xpath_boolean('<a><b>1</b></a>','a/b') | struct<xpath_boolean(<a><b>1</b></a>, a/b):boolean> |
```
VariantEndToEndSuite.scala

```diff
@@ -81,4 +81,35 @@ class VariantEndToEndSuite extends QueryTest with SharedSparkSession {
     val expected = new VariantVal(v.getValue, v.getMetadata)
     checkAnswer(variantDF, Seq(Row(expected)))
   }
+
+  test("schema_of_variant") {
+    def check(json: String, expected: String): Unit = {
+      val df = Seq(json).toDF("j").selectExpr("schema_of_variant(parse_json(j))")
+      checkAnswer(df, Seq(Row(expected)))
+    }
+
+    check("null", "VOID")
+    check("1", "BIGINT")
+    check("1.0", "DECIMAL(1,0)")
+    check("1E0", "DOUBLE")
+    check("true", "BOOLEAN")
+    check("\"2000-01-01\"", "STRING")
+    check("""{"a":0}""", "STRUCT<a: BIGINT>")
+    check("""{"b": {"c": "c"}, "a":["a"]}""", "STRUCT<a: ARRAY<STRING>, b: STRUCT<c: STRING>>")
+    check("[]", "ARRAY<VOID>")
+    check("[false]", "ARRAY<BOOLEAN>")
+    check("[null, 1, 1.0]", "ARRAY<DECIMAL(20,0)>")
+    check("[null, 1, 1.1]", "ARRAY<DECIMAL(21,1)>")
+    check("[123456.789, 123.456789]", "ARRAY<DECIMAL(12,6)>")
+    check("[1, 11111111111111111111111111111111111111]", "ARRAY<DECIMAL(38,0)>")
+    check("[1.1, 11111111111111111111111111111111111111]", "ARRAY<DOUBLE>")
+    check("[1, \"1\"]", "ARRAY<VARIANT>")
+    check("[{}, true]", "ARRAY<VARIANT>")
+    check("""[{"c": ""}, {"a": null}, {"b": 1}]""", "ARRAY<STRUCT<a: VOID, b: BIGINT, c: STRING>>")
+    check("""[{"a": ""}, {"a": null}, {"b": 1}]""", "ARRAY<STRUCT<a: STRING, b: BIGINT>>")
+    check(
+      """[{"a": 1, "b": null}, {"b": true, "a": 1E0}]""",
+      "ARRAY<STRUCT<a: DOUBLE, b: BOOLEAN>>"
+    )
+  }
 }
```
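
For completeness, a hypothetical end-to-end run in the same style as the suite (the expected string follows from the struct-merging rules exercised above, not from an actual test in this commit):

```scala
import spark.implicits._  // assumes a spark-shell session

// BIGINT and BOOLEAN share no common type, so field `a` falls back to VARIANT.
Seq("""[{"a": 1}, {"a": true}]""").toDF("j")
  .selectExpr("schema_of_variant(parse_json(j))")
  .show(truncate = false)
// ARRAY<STRUCT<a: VARIANT>>
```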
