diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index ecba8b263c412..bbc063c321037 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -822,6 +822,7 @@ object FunctionRegistry {
     expression[ParseJson]("parse_json"),
     expressionBuilder("variant_get", VariantGetExpressionBuilder),
     expressionBuilder("try_variant_get", TryVariantGetExpressionBuilder),
+    expression[SchemaOfVariant]("schema_of_variant"),
 
     // cast
     expression[Cast]("cast"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/EncoderUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/EncoderUtils.scala
index 45598b6a66f2d..20f86a32c1a1d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/EncoderUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/EncoderUtils.scala
@@ -23,8 +23,8 @@ import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder, C
 import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.types.{PhysicalBinaryType, PhysicalIntegerType, PhysicalLongType}
 import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
-import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, CalendarIntervalType, DataType, DateType, DayTimeIntervalType, Decimal, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, ObjectType, ShortType, StringType, StructType, TimestampNTZType, TimestampType, UserDefinedType, YearMonthIntervalType}
-import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
+import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, CalendarIntervalType, DataType, DateType, DayTimeIntervalType, Decimal, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, ObjectType, ShortType, StringType, StructType, TimestampNTZType, TimestampType, UserDefinedType, VariantType, YearMonthIntervalType}
+import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String, VariantVal}
 
 /**
  * Helper class for Generating [[ExpressionEncoder]]s.
@@ -122,7 +122,8 @@ object EncoderUtils {
     TimestampType -> classOf[PhysicalLongType.InternalType],
     TimestampNTZType -> classOf[PhysicalLongType.InternalType],
     BinaryType -> classOf[PhysicalBinaryType.InternalType],
-    CalendarIntervalType -> classOf[CalendarInterval]
+    CalendarIntervalType -> classOf[CalendarInterval],
+    VariantType -> classOf[VariantVal]
   )
 
   val typeBoxedJavaMapping: Map[DataType, Class[_]] = Map[DataType, Class[_]](
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
index 4681326136c70..2f2b5923fed76 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.variant
 
 import scala.util.parsing.combinator.RegexParsers
 
+import org.apache.spark.SparkRuntimeException
 import org.apache.spark.sql.catalyst.analysis.ExpressionBuilder
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
@@ -26,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
+import org.apache.spark.sql.catalyst.json.JsonInferSchema
 import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, VARIANT_GET}
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
 import org.apache.spark.sql.catalyst.util.DateTimeConstants._
@@ -403,3 +405,87 @@ object VariantGetExpressionBuilder extends VariantGetExpressionBuilderBase(true)
 )
 // scalastyle:on line.size.limit
 object TryVariantGetExpressionBuilder extends VariantGetExpressionBuilderBase(false)
+
+@ExpressionDescription(
+  usage = "_FUNC_(v) - Returns schema in the SQL format of a variant.",
+  examples = """
+    Examples:
+      > SELECT _FUNC_(parse_json('null'));
+       VOID
+      > SELECT _FUNC_(parse_json('[{"b":true,"a":0}]'));
+       ARRAY<STRUCT<a: BIGINT, b: BOOLEAN>>
+  """,
+  since = "4.0.0",
+  group = "variant_funcs"
+)
+case class SchemaOfVariant(child: Expression)
+  extends UnaryExpression
+  with RuntimeReplaceable
+  with ExpectsInputTypes {
+  override lazy val replacement: Expression = StaticInvoke(
+    SchemaOfVariant.getClass,
+    StringType,
+    "schemaOfVariant",
+    Seq(child),
+    inputTypes,
+    returnNullable = false)
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(VariantType)
+
+  override def dataType: DataType = StringType
+
+  override def prettyName: String = "schema_of_variant"
+
+  override protected def withNewChildInternal(newChild: Expression): SchemaOfVariant =
+    copy(child = newChild)
+}
+
+object SchemaOfVariant {
+  /** The actual implementation of the `SchemaOfVariant` expression. */
+  def schemaOfVariant(input: VariantVal): UTF8String = {
+    val v = new Variant(input.getValue, input.getMetadata)
+    UTF8String.fromString(schemaOf(v).sql)
+  }
+
+  /**
+   * Return the schema of a variant. Struct fields are guaranteed to be sorted alphabetically.
+   */
+  def schemaOf(v: Variant): DataType = v.getType match {
+    case Type.OBJECT =>
+      val size = v.objectSize()
+      val fields = new Array[StructField](size)
+      for (i <- 0 until size) {
+        val field = v.getFieldAtIndex(i)
+        fields(i) = StructField(field.key, schemaOf(field.value))
+      }
+      // According to the variant spec, object fields must be sorted alphabetically. So we don't
+      // have to sort, but just need to validate they are sorted.
+      for (i <- 1 until size) {
+        if (fields(i - 1).name >= fields(i).name) {
+          throw new SparkRuntimeException("MALFORMED_VARIANT", Map.empty)
+        }
+      }
+      StructType(fields)
+    case Type.ARRAY =>
+      var elementType: DataType = NullType
+      for (i <- 0 until v.arraySize()) {
+        elementType = mergeSchema(elementType, schemaOf(v.getElementAtIndex(i)))
+      }
+      ArrayType(elementType)
+    case Type.NULL => NullType
+    case Type.BOOLEAN => BooleanType
+    case Type.LONG => LongType
+    case Type.STRING => StringType
+    case Type.DOUBLE => DoubleType
+    case Type.DECIMAL =>
+      val d = v.getDecimal
+      DecimalType(d.precision(), d.scale())
+  }
+
+  /**
+   * Returns the tightest common type for two given data types. Input struct fields are assumed to
+   * be sorted alphabetically.
+   */
+  def mergeSchema(t1: DataType, t2: DataType): DataType =
+    JsonInferSchema.compatibleType(t1, t2, VariantType)
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala
index 12c1be7c0de70..7ee522226e3ec 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala
@@ -360,8 +360,10 @@ object JsonInferSchema {
 
   /**
    * Returns the most general data type for two given data types.
+   * When the two types are incompatible, return `defaultDataType` as a fallback result.
    */
-  def compatibleType(t1: DataType, t2: DataType): DataType = {
+  def compatibleType(
+      t1: DataType, t2: DataType, defaultDataType: DataType = StringType): DataType = {
     TypeCoercion.findTightestCommonType(t1, t2).getOrElse {
       // t1 or t2 is a StructType, ArrayType, or an unexpected type.
       (t1, t2) match {
@@ -399,7 +401,8 @@ object JsonInferSchema {
             val f2Name = fields2(f2Idx).name
             val comp = f1Name.compareTo(f2Name)
             if (comp == 0) {
-              val dataType = compatibleType(fields1(f1Idx).dataType, fields2(f2Idx).dataType)
+              val dataType = compatibleType(
+                fields1(f1Idx).dataType, fields2(f2Idx).dataType, defaultDataType)
               newFields.add(StructField(f1Name, dataType, nullable = true))
               f1Idx += 1
               f2Idx += 1
@@ -422,21 +425,22 @@ object JsonInferSchema {
           StructType(newFields.toArray(emptyStructFieldArray))
 
         case (ArrayType(elementType1, containsNull1), ArrayType(elementType2, containsNull2)) =>
-          ArrayType(compatibleType(elementType1, elementType2), containsNull1 || containsNull2)
+          ArrayType(
+            compatibleType(elementType1, elementType2, defaultDataType),
+            containsNull1 || containsNull2)
 
         // The case that given `DecimalType` is capable of given `IntegralType` is handled in
         // `findTightestCommonType`. Both cases below will be executed only when the given
         // `DecimalType` is not capable of the given `IntegralType`.
         case (t1: IntegralType, t2: DecimalType) =>
-          compatibleType(DecimalType.forType(t1), t2)
+          compatibleType(DecimalType.forType(t1), t2, defaultDataType)
         case (t1: DecimalType, t2: IntegralType) =>
-          compatibleType(t1, DecimalType.forType(t2))
+          compatibleType(t1, DecimalType.forType(t2), defaultDataType)
 
         case (TimestampNTZType, TimestampType) | (TimestampType, TimestampNTZType) =>
           TimestampType
 
-        // strings and every string is a Json object.
-        case (_, _) => StringType
+        case (_, _) => defaultDataType
       }
     }
   }
diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
index c461e3ec09fba..05491034e6c76 100644
--- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
+++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
@@ -437,6 +437,7 @@
 | org.apache.spark.sql.catalyst.expressions.aggregate.VarianceSamp | var_samp | SELECT var_samp(col) FROM VALUES (1), (2), (3) AS tab(col) | struct<var_samp(col):double> |
 | org.apache.spark.sql.catalyst.expressions.aggregate.VarianceSamp | variance | SELECT variance(col) FROM VALUES (1), (2), (3) AS tab(col) | struct<variance(col):double> |
 | org.apache.spark.sql.catalyst.expressions.variant.ParseJson | parse_json | SELECT parse_json('{"a":1,"b":0.8}') | struct<parse_json({"a":1,"b":0.8}):variant> |
+| org.apache.spark.sql.catalyst.expressions.variant.SchemaOfVariant | schema_of_variant | SELECT schema_of_variant(parse_json('null')) | struct<schema_of_variant(parse_json(null)):string> |
 | org.apache.spark.sql.catalyst.expressions.variant.TryVariantGetExpressionBuilder | try_variant_get | SELECT try_variant_get(parse_json('{"a": 1}'), '$.a', 'int') | struct<try_variant_get(parse_json({"a": 1}), $.a):int> |
 | org.apache.spark.sql.catalyst.expressions.variant.VariantGetExpressionBuilder | variant_get | SELECT variant_get(parse_json('{"a": 1}'), '$.a', 'int') | struct<variant_get(parse_json({"a": 1}), $.a):int> |
 | org.apache.spark.sql.catalyst.expressions.xml.XPathBoolean | xpath_boolean | SELECT xpath_boolean('<a><b>1</b></a>','a/b') | struct<xpath_boolean(<a><b>1</b></a>, a/b):boolean> |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala
index cf12001fa71bd..d8b1dca21ca67 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala
@@ -81,4 +81,35 @@ class VariantEndToEndSuite extends QueryTest with SharedSparkSession {
     val expected = new VariantVal(v.getValue, v.getMetadata)
     checkAnswer(variantDF, Seq(Row(expected)))
   }
+
+  test("schema_of_variant") {
+    def check(json: String, expected: String): Unit = {
+      val df = Seq(json).toDF("j").selectExpr("schema_of_variant(parse_json(j))")
+      checkAnswer(df, Seq(Row(expected)))
+    }
+
+    check("null", "VOID")
+    check("1", "BIGINT")
+    check("1.0", "DECIMAL(1,0)")
+    check("1E0", "DOUBLE")
+    check("true", "BOOLEAN")
+    check("\"2000-01-01\"", "STRING")
+    check("""{"a":0}""", "STRUCT<a: BIGINT>")
+    check("""{"b": {"c": "c"}, "a":["a"]}""", "STRUCT<a: ARRAY<STRING>, b: STRUCT<c: STRING>>")
+    check("[]", "ARRAY<VOID>")
+    check("[false]", "ARRAY<BOOLEAN>")
+    check("[null, 1, 1.0]", "ARRAY<DECIMAL(20,0)>")
+    check("[null, 1, 1.1]", "ARRAY<DECIMAL(21,1)>")
+    check("[123456.789, 123.456789]", "ARRAY<DECIMAL(12,6)>")
+    check("[1, 11111111111111111111111111111111111111]", "ARRAY<DECIMAL(38,0)>")
+    check("[1.1, 11111111111111111111111111111111111111]", "ARRAY<DOUBLE>")
+    check("[1, \"1\"]", "ARRAY<VARIANT>")
+    check("[{}, true]", "ARRAY<VARIANT>")
+    check("""[{"c": ""}, {"a": null}, {"b": 1}]""", "ARRAY<STRUCT<a: VOID, b: BIGINT, c: STRING>>")
+    check("""[{"a": ""}, {"a": null}, {"b": 1}]""", "ARRAY<STRUCT<a: STRING, b: BIGINT>>")
+    check(
+      """[{"a": 1, "b": null}, {"b": true, "a": 1E0}]""",
+      "ARRAY<STRUCT<a: DOUBLE, b: BOOLEAN>>"
+    )
+  }
 }
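
Usage sketch (not part of the patch; the queries are illustrative, and the expected outputs are derived from `SchemaOfVariant.schemaOf` plus the `defaultDataType` fallback added to `JsonInferSchema.compatibleType` above):

    -- Object fields come back sorted alphabetically; JSON integers map to BIGINT,
    -- and BIGINT merged with DECIMAL(2,1) widens to DECIMAL(21,1).
    spark-sql> SELECT schema_of_variant(parse_json('{"b": "hello", "a": [1, 2.5]}'));
    STRUCT<a: ARRAY<DECIMAL(21,1)>, b: STRING>

    -- Element types with no common supertype fall back to VARIANT (the fallback that
    -- `mergeSchema` passes in), rather than the STRING fallback used for JSON inference.
    spark-sql> SELECT schema_of_variant(parse_json('[1, "one"]'));
    ARRAY<VARIANT>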