From 320fb0ac148ad5da25b9367f94cc1651af202092 Mon Sep 17 00:00:00 2001 From: panbingkun Date: Tue, 15 Oct 2024 16:48:24 +0800 Subject: [PATCH] [SPARK-49954][SQL] Codegen Support for SchemaOfJson (by Invoke & RuntimeReplaceable) --- .../sql/catalyst/expressions/ExprUtils.scala | 4 +- .../json/JsonExpressionEvalUtils.scala | 53 +++++++++++++++++++ .../expressions/jsonExpressions.scala | 38 ++++++------- .../function_schema_of_json.explain | 2 +- ...nction_schema_of_json_with_options.explain | 2 +- 5 files changed, 74 insertions(+), 25 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/json/JsonExpressionEvalUtils.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala index 08cb03edb78b6..38b927f5bbf38 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala @@ -32,11 +32,11 @@ import org.apache.spark.sql.internal.types.{AbstractMapType, StringTypeWithCaseA import org.apache.spark.sql.types.{DataType, MapType, StringType, StructType, VariantType} import org.apache.spark.unsafe.types.UTF8String -object ExprUtils extends QueryErrorsBase { +object ExprUtils extends EvalHelper with QueryErrorsBase { def evalTypeExpr(exp: Expression): DataType = { if (exp.foldable) { - exp.eval() match { + prepareForEval(exp).eval() match { case s: UTF8String if s != null => val dataType = DataType.parseTypeWithFallback( s.toString, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/json/JsonExpressionEvalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/json/JsonExpressionEvalUtils.scala new file mode 100644 index 0000000000000..65c95c8240f4f --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/json/JsonExpressionEvalUtils.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.expressions.json + +import com.fasterxml.jackson.core.JsonFactory + +import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JsonInferSchema, JSONOptions} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{ArrayType, DataType, StructType} +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.Utils + +object JsonExpressionEvalUtils { + + def schemaOfJson( + jsonFactory: JsonFactory, + jsonOptions: JSONOptions, + jsonInferSchema: JsonInferSchema, + json: UTF8String): UTF8String = { + val dt = Utils.tryWithResource(CreateJacksonParser.utf8String(jsonFactory, json)) { parser => + parser.nextToken() + // To match with schema inference from JSON datasource. + jsonInferSchema.inferField(parser) match { + case st: StructType => + jsonInferSchema.canonicalizeType(st, jsonOptions).getOrElse(StructType(Nil)) + case at: ArrayType if at.elementType.isInstanceOf[StructType] => + jsonInferSchema + .canonicalizeType(at.elementType, jsonOptions) + .map(ArrayType(_, containsNull = at.containsNull)) + .getOrElse(ArrayType(StructType(Nil), containsNull = at.containsNull)) + case other: DataType => + jsonInferSchema.canonicalizeType(other, jsonOptions).getOrElse( + SQLConf.get.defaultStringType) + } + } + + UTF8String.fromString(dt.sql) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index e01531cc821c9..3118fe9a2eb44 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, CodegenFallback, ExprCode} import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper -import org.apache.spark.sql.catalyst.expressions.json.JsonExpressionUtils +import org.apache.spark.sql.catalyst.expressions.json.{JsonExpressionEvalUtils, JsonExpressionUtils} import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke import org.apache.spark.sql.catalyst.expressions.variant.VariantExpressionEvalUtils import org.apache.spark.sql.catalyst.json._ @@ -878,7 +878,9 @@ case class StructsToJson( case class SchemaOfJson( child: Expression, options: Map[String, String]) - extends UnaryExpression with CodegenFallback with QueryErrorsBase { + extends UnaryExpression + with RuntimeReplaceable + with QueryErrorsBase { def this(child: Expression) = this(child, Map.empty[String, String]) @@ -919,26 +921,20 @@ case class SchemaOfJson( } } - override def eval(v: InternalRow): Any = { - val dt = Utils.tryWithResource(CreateJacksonParser.utf8String(jsonFactory, json)) { parser => - parser.nextToken() - // To match with schema inference from JSON datasource. - jsonInferSchema.inferField(parser) match { - case st: StructType => - jsonInferSchema.canonicalizeType(st, jsonOptions).getOrElse(StructType(Nil)) - case at: ArrayType if at.elementType.isInstanceOf[StructType] => - jsonInferSchema - .canonicalizeType(at.elementType, jsonOptions) - .map(ArrayType(_, containsNull = at.containsNull)) - .getOrElse(ArrayType(StructType(Nil), containsNull = at.containsNull)) - case other: DataType => - jsonInferSchema.canonicalizeType(other, jsonOptions).getOrElse( - SQLConf.get.defaultStringType) - } - } + @transient private lazy val jsonFactoryObjectType = ObjectType(classOf[JsonFactory]) + @transient private lazy val jsonOptionsObjectType = ObjectType(classOf[JSONOptions]) + @transient private lazy val jsonInferSchemaObjectType = ObjectType(classOf[JsonInferSchema]) - UTF8String.fromString(dt.sql) - } + override def replacement: Expression = StaticInvoke( + JsonExpressionEvalUtils.getClass, + dataType, + "schemaOfJson", + Seq(Literal(jsonFactory, jsonFactoryObjectType), + Literal(jsonOptions, jsonOptionsObjectType), + Literal(jsonInferSchema, jsonInferSchemaObjectType), + child), + Seq(jsonFactoryObjectType, jsonOptionsObjectType, jsonInferSchemaObjectType, child.dataType) + ) override def prettyName: String = "schema_of_json" diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_schema_of_json.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_schema_of_json.explain index 8ec799bc58084..b400aeeca5af2 100644 --- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_schema_of_json.explain +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_schema_of_json.explain @@ -1,2 +1,2 @@ -Project [schema_of_json([{"col":01}]) AS schema_of_json([{"col":01}])#0] +Project [static_invoke(JsonExpressionEvalUtils.schemaOfJson(com.fasterxml.jackson.core.JsonFactory, org.apache.spark.sql.catalyst.json.JSONOptions, org.apache.spark.sql.catalyst.json.JsonInferSchema, [{"col":01}])) AS schema_of_json([{"col":01}])#0] +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_schema_of_json_with_options.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_schema_of_json_with_options.explain index 13867949177a4..b400aeeca5af2 100644 --- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_schema_of_json_with_options.explain +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_schema_of_json_with_options.explain @@ -1,2 +1,2 @@ -Project [schema_of_json([{"col":01}], (allowNumericLeadingZeros,true)) AS schema_of_json([{"col":01}])#0] +Project [static_invoke(JsonExpressionEvalUtils.schemaOfJson(com.fasterxml.jackson.core.JsonFactory, org.apache.spark.sql.catalyst.json.JSONOptions, org.apache.spark.sql.catalyst.json.JsonInferSchema, [{"col":01}])) AS schema_of_json([{"col":01}])#0] +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]