resolve conflict
panbingkun committed Oct 12, 2024
2 parents 8607bed + 62ade5f commit 4691be2
Showing 9 changed files with 87 additions and 50 deletions.
4 changes: 4 additions & 0 deletions pom.xml
@@ -3075,6 +3075,10 @@
reduce the cost of migration in subsequent versions.
-->
<arg>-Wconf:cat=deprecation&amp;msg=it will become a keyword in Scala 3:e</arg>
<!--
SPARK-49937 ban calls to the method `SparkThrowable#getErrorClass`
-->
<arg>-Wconf:cat=deprecation&amp;msg=method getErrorClass in trait SparkThrowable is deprecated:e</arg>
</args>
<jvmArgs>
<jvmArg>-Xss128m</jvmArg>
4 changes: 3 additions & 1 deletion project/SparkBuild.scala
@@ -254,7 +254,9 @@ object SparkBuild extends PomBuild {
// reduce the cost of migration in subsequent versions.
"-Wconf:cat=deprecation&msg=it will become a keyword in Scala 3:e",
// SPARK-46938 to prevent enum scan on pmml-model, under spark-mllib module.
"-Wconf:cat=other&site=org.dmg.pmml.*:w"
"-Wconf:cat=other&site=org.dmg.pmml.*:w",
// SPARK-49937 ban calls to the method `SparkThrowable#getErrorClass`
"-Wconf:cat=deprecation&msg=method getErrorClass in trait SparkThrowable is deprecated:e"
)
}
)
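Both build definitions add the same compiler filter: `-Wconf` matches the deprecation warning for `SparkThrowable#getErrorClass` by category and message and escalates it to a compile error (`:e`). A minimal sketch of the kind of call the rule now rejects; note that `getCondition` as the replacement accessor is an assumption taken from the deprecation cycle, not from this diff:

import org.apache.spark.SparkThrowable

def errorId(e: Throwable): Option[String] = e match {
  // Under the new -Wconf rule this line would no longer compile:
  // case st: SparkThrowable => Option(st.getErrorClass)  // error: method getErrorClass ... is deprecated
  case st: SparkThrowable => Option(st.getCondition)      // assumed non-deprecated replacement
  case _ => None
}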
@@ -14,6 +14,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.expressions.json;

import java.io.IOException;
@@ -30,6 +31,34 @@

public class JsonExpressionUtils {

public static Integer lengthOfJsonArray(UTF8String json) {
// return null for null input
if (json == null) {
return null;
}
try (JsonParser jsonParser =
CreateJacksonParser.utf8String(SharedFactory.jsonFactory(), json)) {
if (jsonParser.nextToken() == null) {
return null;
}
// Only JSON arrays are supported by this function.
if (jsonParser.currentToken() != JsonToken.START_ARRAY) {
return null;
}
// Parse the array to compute its length.
int length = 0;
// Keep traversing until the end of the JSON array
while (jsonParser.nextToken() != JsonToken.END_ARRAY) {
length += 1;
// Skip all children of inner objects or arrays
jsonParser.skipChildren();
}
return length;
} catch (IOException e) {
return null;
}
}

public static GenericArrayData jsonObjectKeys(UTF8String json) {
// return null for `NULL` input
if (json == null) {
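This new Java helper is what `json_array_length` gets routed to in the Scala changes below. A quick sketch of its behavior when called directly (a sketch, not a test from this commit; the results follow from the code above):

import org.apache.spark.sql.catalyst.expressions.json.JsonExpressionUtils
import org.apache.spark.unsafe.types.UTF8String

// Each top-level element counts once; skipChildren() steps over nested structure.
JsonExpressionUtils.lengthOfJsonArray(UTF8String.fromString("[1, [2, 3], {\"a\": 4}]"))  // 3
// Non-array JSON and malformed input yield null rather than an error.
JsonExpressionUtils.lengthOfJsonArray(UTF8String.fromString("{\"a\": 1}"))  // null
JsonExpressionUtils.lengthOfJsonArray(null)  // null for null input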
@@ -27,6 +27,7 @@ import org.apache.spark.SparkException.internalError
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedAttribute, UnresolvedSeed}
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
import org.apache.spark.sql.catalyst.expressions.KnownNotContainsNull
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
@@ -5330,15 +5331,12 @@ case class ArrayCompact(child: Expression)
child.dataType.asInstanceOf[ArrayType].elementType, true)
lazy val lambda = LambdaFunction(isNotNull(lv), Seq(lv))

override lazy val replacement: Expression = ArrayFilter(child, lambda)
override lazy val replacement: Expression = KnownNotContainsNull(ArrayFilter(child, lambda))

override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType)

override def prettyName: String = "array_compact"

override def dataType: ArrayType =
child.dataType.asInstanceOf[ArrayType].copy(containsNull = false)

override protected def withNewChildInternal(newChild: Expression): ArrayCompact =
copy(child = newChild)
}
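Context for this hunk: `ArrayCompact` is `RuntimeReplaceable`, so once the optimizer substitutes `replacement`, the deleted `dataType` override no longer has any effect; the `containsNull = false` guarantee has to travel with the replacement expression itself, hence the `KnownNotContainsNull` wrapper. A small sketch of the resulting type (written against the classes in this diff, not taken from the PR):

import org.apache.spark.sql.catalyst.expressions.{ArrayCompact, Literal}
import org.apache.spark.sql.types.{ArrayType, IntegerType}

// An array<int> literal whose element type still admits nulls.
val arr = Literal.create(Seq[Any](1, null, 2), ArrayType(IntegerType, containsNull = true))

// The tighter element type now lives on the expression that survives optimization.
assert(ArrayCompact(arr).replacement.dataType == ArrayType(IntegerType, containsNull = false))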
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, FalseLiteral}
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.types.{ArrayType, DataType}

trait TaggingExpression extends UnaryExpression {
override def nullable: Boolean = child.nullable
@@ -52,6 +52,17 @@ case class KnownNotNull(child: Expression) extends TaggingExpression {
copy(child = newChild)
}

case class KnownNotContainsNull(child: Expression) extends TaggingExpression {
override def dataType: DataType =
child.dataType.asInstanceOf[ArrayType].copy(containsNull = false)

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
child.genCode(ctx)

override protected def withNewChildInternal(newChild: Expression): KnownNotContainsNull =
copy(child = newChild)
}

case class KnownFloatingPointNormalized(child: Expression) extends TaggingExpression {
override protected def withNewChildInternal(newChild: Expression): KnownFloatingPointNormalized =
copy(child = newChild)
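Like the other tagging expressions in this file, `KnownNotContainsNull` is a runtime no-op: `doGenCode` emits only the child's code, and `TaggingExpression` already delegates `eval`, so just the statically reported type changes. A brief sketch of the intended contract (assumed usage, mirroring `KnownNotNull`):

import org.apache.spark.sql.catalyst.expressions.{KnownNotContainsNull, Literal}
import org.apache.spark.sql.types.{ArrayType, IntegerType}

val child = Literal.create(Seq[Any](1, 2), ArrayType(IntegerType, containsNull = true))
val tagged = KnownNotContainsNull(child)

assert(tagged.dataType == ArrayType(IntegerType, containsNull = false))  // type refined
assert(tagged.eval() == child.eval())  // evaluation passes straight through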
@@ -968,54 +968,26 @@
group = "json_funcs",
since = "3.1.0"
)
case class LengthOfJsonArray(child: Expression) extends UnaryExpression
with CodegenFallback with ExpectsInputTypes {
case class LengthOfJsonArray(child: Expression)
extends UnaryExpression
with ExpectsInputTypes
with RuntimeReplaceable {

override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity)
override def dataType: DataType = IntegerType
override def nullable: Boolean = true
override def prettyName: String = "json_array_length"

override def eval(input: InternalRow): Any = {
val json = child.eval(input).asInstanceOf[UTF8String]
// return null for null input
if (json == null) {
return null
}

try {
Utils.tryWithResource(CreateJacksonParser.utf8String(SharedFactory.jsonFactory, json)) {
parser => {
// return null if null array is encountered.
if (parser.nextToken() == null) {
return null
}
// Parse the array to compute its length.
parseCounter(parser, input)
}
}
} catch {
case _: JsonProcessingException | _: IOException => null
}
}

private def parseCounter(parser: JsonParser, input: InternalRow): Any = {
var length = 0
// Only JSON arrays are supported by this function.
if (parser.currentToken != JsonToken.START_ARRAY) {
return null
}
// Keep traversing until the end of the JSON array
while (parser.nextToken() != JsonToken.END_ARRAY) {
length += 1
// Skip all children of inner objects or arrays
parser.skipChildren()
}
length
}

override protected def withNewChildInternal(newChild: Expression): LengthOfJsonArray =
copy(child = newChild)

override def replacement: Expression = StaticInvoke(
classOf[JsonExpressionUtils],
dataType,
"lengthOfJsonArray",
Seq(child),
inputTypes
)
}

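The net effect of this hunk: the hand-written interpreted `eval` (and its `CodegenFallback`) is dropped in favor of `RuntimeReplaceable`, so the optimizer rewrites `json_array_length` into a `StaticInvoke` of the Java helper, which participates in normal codegen. The rewrite is visible in query plans; a hedged sketch, assuming a `SparkSession` named `spark`:

// The optimized plan should now show static_invoke instead of json_array_length:
spark.sql("SELECT json_array_length('[1, 2, [3, 4]]')").explain(true)
// ... Project [static_invoke(JsonExpressionUtils.lengthOfJsonArray(...)) ...]
spark.sql("SELECT json_array_length('[1, 2, [3, 4]]')").show()  // 3, semantics unchanged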
@@ -21,13 +21,13 @@ import org.apache.spark.SparkException
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{Add, Alias, AttributeReference, IntegerLiteral, Literal, Multiply, NamedExpression, Remainder}
import org.apache.spark.sql.catalyst.expressions.{Add, Alias, ArrayCompact, AttributeReference, CreateArray, CreateStruct, IntegerLiteral, Literal, MapFromEntries, Multiply, NamedExpression, Remainder}
import org.apache.spark.sql.catalyst.expressions.aggregate.Sum
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LocalRelation, LogicalPlan, OneRowRelation, Project}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.IntegerType
import org.apache.spark.sql.types.{ArrayType, IntegerType, MapType, StructField, StructType}

/**
* A dummy optimizer rule for testing that decrements integer literals until 0.
@@ -313,4 +313,25 @@ class OptimizerSuite extends PlanTest {
assert(message1.contains("not a valid aggregate expression"))
}
}

test("SPARK-49924: Keep containsNull after ArrayCompact replacement") {
val optimizer = new SimpleTestOptimizer() {
override def defaultBatches: Seq[Batch] =
Batch("test", fixedPoint,
ReplaceExpressions) :: Nil
}

val array1 = ArrayCompact(CreateArray(Literal(1) :: Literal.apply(null) :: Nil, false))
val plan1 = Project(Alias(array1, "arr")() :: Nil, OneRowRelation()).analyze
val optimized1 = optimizer.execute(plan1)
assert(optimized1.schema ===
StructType(StructField("arr", ArrayType(IntegerType, false), false) :: Nil))

val struct = CreateStruct(Literal(1) :: Literal(2) :: Nil)
val array2 = ArrayCompact(CreateArray(struct :: Literal.apply(null) :: Nil, false))
val plan2 = Project(Alias(MapFromEntries(array2), "map")() :: Nil, OneRowRelation()).analyze
val optimized2 = optimizer.execute(plan2)
assert(optimized2.schema ===
StructType(StructField("map", MapType(IntegerType, IntegerType, false), false) :: Nil))
}
}
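The preserved flag is also visible in end-user schemas; a quick spark-shell check (expected output inferred from the assertions above, not copied from the commit):

val df = spark.sql("SELECT array_compact(array(1, NULL, 2)) AS arr")
df.printSchema()
// root
//  |-- arr: array (nullable = false)
//  |    |-- element: integer (containsNull = false)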
@@ -1,2 +1,2 @@
Project [filter(e#0, lambdafunction(isnotnull(lambda arg#0), lambda arg#0, false)) AS array_compact(e)#0]
Project [knownnotcontainsnull(filter(e#0, lambdafunction(isnotnull(lambda arg#0), lambda arg#0, false))) AS array_compact(e)#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
@@ -1,2 +1,2 @@
Project [json_array_length(g#0) AS json_array_length(g)#0]
Project [static_invoke(JsonExpressionUtils.lengthOfJsonArray(g#0)) AS json_array_length(g)#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
