diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/json/JsonExpressionUtils.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/json/JsonExpressionUtils.java index ca2ae80042df7..07e13610aa950 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/json/JsonExpressionUtils.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/json/JsonExpressionUtils.java @@ -18,12 +18,15 @@ package org.apache.spark.sql.catalyst.expressions.json; import java.io.IOException; +import java.util.ArrayList; +import java.util.List; import com.fasterxml.jackson.core.JsonParser; import com.fasterxml.jackson.core.JsonToken; import org.apache.spark.sql.catalyst.expressions.SharedFactory; import org.apache.spark.sql.catalyst.json.CreateJacksonParser; +import org.apache.spark.sql.catalyst.util.GenericArrayData; import org.apache.spark.unsafe.types.UTF8String; public class JsonExpressionUtils { @@ -55,4 +58,32 @@ public static Integer lengthOfJsonArray(UTF8String json) { return null; } } + + public static GenericArrayData jsonObjectKeys(UTF8String json) { + // return null for `NULL` input + if (json == null) { + return null; + } + try (JsonParser jsonParser = + CreateJacksonParser.utf8String(SharedFactory.jsonFactory(), json)) { + // return null if an empty string or any other valid JSON string is encountered + if (jsonParser.nextToken() == null || jsonParser.currentToken() != JsonToken.START_OBJECT) { + return null; + } + // Parse the JSON string to get all the keys of outermost JSON object + List<UTF8String> arrayBufferOfKeys = new ArrayList<>(); + + // traverse until the end of input and ensure it returns valid key + while (jsonParser.nextValue() != null && jsonParser.currentName() != null) { + // add current fieldName to the ArrayBuffer + arrayBufferOfKeys.add(UTF8String.fromString(jsonParser.currentName())); + + // skip all the children of inner object or array + jsonParser.skipChildren(); + } + 
return new GenericArrayData(arrayBufferOfKeys.toArray()); + } catch (IOException e) { + return null; + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index e1f2b1c1df62a..e01531cc821c9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.expressions import java.io._ -import scala.collection.mutable.ArrayBuffer import scala.util.parsing.combinator.RegexParsers import com.fasterxml.jackson.core._ @@ -1014,50 +1013,23 @@ case class LengthOfJsonArray(child: Expression) group = "json_funcs", since = "3.1.0" ) -case class JsonObjectKeys(child: Expression) extends UnaryExpression with CodegenFallback - with ExpectsInputTypes { +case class JsonObjectKeys(child: Expression) + extends UnaryExpression + with ExpectsInputTypes + with RuntimeReplaceable { override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity) override def dataType: DataType = ArrayType(SQLConf.get.defaultStringType) override def nullable: Boolean = true override def prettyName: String = "json_object_keys" - override def eval(input: InternalRow): Any = { - val json = child.eval(input).asInstanceOf[UTF8String] - // return null for `NULL` input - if(json == null) { - return null - } - - try { - Utils.tryWithResource(CreateJacksonParser.utf8String(SharedFactory.jsonFactory, json)) { - parser => { - // return null if an empty string or any other valid JSON string is encountered - if (parser.nextToken() == null || parser.currentToken() != JsonToken.START_OBJECT) { - return null - } - // Parse the JSON string to get all the keys of outermost JSON object - getJsonKeys(parser, input) - } - } - } catch { - case _: 
JsonProcessingException | _: IOException => null - } - } - - private def getJsonKeys(parser: JsonParser, input: InternalRow): GenericArrayData = { - val arrayBufferOfKeys = ArrayBuffer.empty[UTF8String] - - // traverse until the end of input and ensure it returns valid key - while(parser.nextValue() != null && parser.currentName() != null) { - // add current fieldName to the ArrayBuffer - arrayBufferOfKeys += UTF8String.fromString(parser.currentName) - - // skip all the children of inner object or array - parser.skipChildren() - } - new GenericArrayData(arrayBufferOfKeys.toArray[Any]) - } + override def replacement: Expression = StaticInvoke( + classOf[JsonExpressionUtils], + dataType, + "jsonObjectKeys", + Seq(child), + inputTypes + ) override protected def withNewChildInternal(newChild: Expression): JsonObjectKeys = copy(child = newChild) diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_json_object_keys.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_json_object_keys.explain index 30153bb192e55..8a2ea5336c160 100644 --- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_json_object_keys.explain +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_json_object_keys.explain @@ -1,2 +1,2 @@ -Project [json_object_keys(g#0) AS json_object_keys(g)#0] +Project [static_invoke(JsonExpressionUtils.jsonObjectKeys(g#0)) AS json_object_keys(g)#0] +- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]