From e96ba8430c693c3bcc9f6797a4779c8e9fadaaba Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Thu, 29 Nov 2018 22:37:02 +0800 Subject: [PATCH] [SPARK-26211][SQL] Fix InSet for binary, and struct and array with null. Currently `InSet` doesn't work properly for binary type, or struct and array type with null value in the set. Because, as for binary type, the `HashSet` doesn't work properly for `Array[Byte]`, and as for struct and array type with null value in the set, the `ordering` will throw a `NPE`. Added a few tests. Closes #23176 from ueshin/issues/SPARK-26211/inset. Authored-by: Takuya UESHIN Signed-off-by: Wenchen Fan (cherry picked from commit b9b68a6dc7d0f735163e980392ea957f2d589923) Signed-off-by: Wenchen Fan --- .../sql/catalyst/expressions/predicates.scala | 31 +++++------- .../catalyst/expressions/PredicateSuite.scala | 50 ++++++++++++++++++- 2 files changed, 62 insertions(+), 19 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index a29b1360541cc..8c42eee81509c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -337,31 +337,26 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with } @transient lazy val set: Set[Any] = child.dataType match { - case _: AtomicType => hset + case t: AtomicType if !t.isInstanceOf[BinaryType] => hset case _: NullType => hset case _ => // for structs use interpreted ordering to be able to compare UnsafeRows with non-UnsafeRows - TreeSet.empty(TypeUtils.getInterpretedOrdering(child.dataType)) ++ hset + TreeSet.empty(TypeUtils.getInterpretedOrdering(child.dataType)) ++ (hset - null) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val setTerm = ctx.addReferenceObj("set", set) - val childGen = child.genCode(ctx) - val setIsNull = if (hasNull) { - s"${ev.isNull} = !${ev.value};" - } else { - "" - } - ev.copy(code = + nullSafeCodeGen(ctx, ev, c => { + val setTerm = ctx.addReferenceObj("set", set) + val setIsNull = if (hasNull) { + s"${ev.isNull} = !${ev.value};" + } else { + "" + } s""" - |${childGen.code} - |${ctx.JAVA_BOOLEAN} ${ev.isNull} = ${childGen.isNull}; - |${ctx.JAVA_BOOLEAN} ${ev.value} = false; - |if (!${ev.isNull}) { - | ${ev.value} = $setTerm.contains(${childGen.value}); - | $setIsNull - |} - """.stripMargin) + |${ev.value} = $setTerm.contains($c); + |$setIsNull + """.stripMargin + }) } override def sql: String = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index 1bfd180ae4393..861fddc28087d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -268,7 +268,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(InSet(nl, nS), null) val primitiveTypes = Seq(IntegerType, FloatType, DoubleType, StringType, ByteType, ShortType, - LongType, BinaryType, BooleanType, DecimalType.USER_DEFAULT, TimestampType) + LongType, BooleanType, DecimalType.USER_DEFAULT, TimestampType) primitiveTypes.foreach { t => val dataGen = RandomDataGenerator.forType(t, nullable = true).get val inputData = Seq.fill(10) { @@ -293,6 +293,54 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { } } + test("INSET: binary") { + val hS = HashSet[Any]() + Array(1.toByte, 2.toByte) + Array(3.toByte) + val nS = HashSet[Any]() + Array(1.toByte, 2.toByte) + Array(3.toByte) + null + val onetwo = Literal(Array(1.toByte, 2.toByte)) + val three = Literal(Array(3.toByte)) + val threefour = Literal(Array(3.toByte, 4.toByte)) + val nl = Literal(null, onetwo.dataType) + checkEvaluation(InSet(onetwo, hS), true) + checkEvaluation(InSet(three, hS), true) + checkEvaluation(InSet(three, nS), true) + checkEvaluation(InSet(threefour, hS), false) + checkEvaluation(InSet(threefour, nS), null) + checkEvaluation(InSet(nl, hS), null) + checkEvaluation(InSet(nl, nS), null) + } + + test("INSET: struct") { + val hS = HashSet[Any]() + Literal.create((1, "a")).value + Literal.create((2, "b")).value + val nS = HashSet[Any]() + Literal.create((1, "a")).value + Literal.create((2, "b")).value + null + val oneA = Literal.create((1, "a")) + val twoB = Literal.create((2, "b")) + val twoC = Literal.create((2, "c")) + val nl = Literal(null, oneA.dataType) + checkEvaluation(InSet(oneA, hS), true) + checkEvaluation(InSet(twoB, hS), true) + checkEvaluation(InSet(twoB, nS), true) + checkEvaluation(InSet(twoC, hS), false) + checkEvaluation(InSet(twoC, nS), null) + checkEvaluation(InSet(nl, hS), null) + checkEvaluation(InSet(nl, nS), null) + } + + test("INSET: array") { + val hS = HashSet[Any]() + Literal.create(Seq(1, 2)).value + Literal.create(Seq(3)).value + val nS = HashSet[Any]() + Literal.create(Seq(1, 2)).value + Literal.create(Seq(3)).value + null + val onetwo = Literal.create(Seq(1, 2)) + val three = Literal.create(Seq(3)) + val threefour = Literal.create(Seq(3, 4)) + val nl = Literal(null, onetwo.dataType) + checkEvaluation(InSet(onetwo, hS), true) + checkEvaluation(InSet(three, hS), true) + checkEvaluation(InSet(three, nS), true) + checkEvaluation(InSet(threefour, hS), false) + checkEvaluation(InSet(threefour, nS), null) + checkEvaluation(InSet(nl, hS), null) + checkEvaluation(InSet(nl, nS), null) + } + private case class MyStruct(a: Long, b: String) private case class MyStruct2(a: MyStruct, b: Array[Int]) private val udt = new ExamplePointUDT