Skip to content

Commit

Permalink
[SPARK-26211][SQL] Fix InSet for binary, and struct and array with null.
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?

Currently `InSet` doesn't work properly for binary type, or struct and array type with null value in the set.
Because, as for binary type, the `HashSet` doesn't work properly for `Array[Byte]`, and as for struct and array type with null value in the set, the `ordering` will throw a `NPE`.

## How was this patch tested?

Added a few tests.

Closes apache#23176 from ueshin/issues/SPARK-26211/inset.

Authored-by: Takuya UESHIN <ueshin@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information
ueshin authored and cloud-fan committed Nov 29, 2018
1 parent 7a83d71 commit b9b68a6
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -367,31 +367,26 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with
}

@transient lazy val set: Set[Any] = child.dataType match {
case _: AtomicType => hset
case t: AtomicType if !t.isInstanceOf[BinaryType] => hset
case _: NullType => hset
case _ =>
// for structs use interpreted ordering to be able to compare UnsafeRows with non-UnsafeRows
TreeSet.empty(TypeUtils.getInterpretedOrdering(child.dataType)) ++ hset
TreeSet.empty(TypeUtils.getInterpretedOrdering(child.dataType)) ++ (hset - null)
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val setTerm = ctx.addReferenceObj("set", set)
val childGen = child.genCode(ctx)
val setIsNull = if (hasNull) {
s"${ev.isNull} = !${ev.value};"
} else {
""
}
ev.copy(code =
code"""
|${childGen.code}
|${CodeGenerator.JAVA_BOOLEAN} ${ev.isNull} = ${childGen.isNull};
|${CodeGenerator.JAVA_BOOLEAN} ${ev.value} = false;
|if (!${ev.isNull}) {
| ${ev.value} = $setTerm.contains(${childGen.value});
| $setIsNull
|}
""".stripMargin)
nullSafeCodeGen(ctx, ev, c => {
val setTerm = ctx.addReferenceObj("set", set)
val setIsNull = if (hasNull) {
s"${ev.isNull} = !${ev.value};"
} else {
""
}
s"""
|${ev.value} = $setTerm.contains($c);
|$setIsNull
""".stripMargin
})
}

override def sql: String = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(InSet(nl, nS), null)

val primitiveTypes = Seq(IntegerType, FloatType, DoubleType, StringType, ByteType, ShortType,
LongType, BinaryType, BooleanType, DecimalType.USER_DEFAULT, TimestampType)
LongType, BooleanType, DecimalType.USER_DEFAULT, TimestampType)
primitiveTypes.foreach { t =>
val dataGen = RandomDataGenerator.forType(t, nullable = true).get
val inputData = Seq.fill(10) {
Expand All @@ -293,6 +293,54 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
}
}

test("INSET: binary") {
val hS = HashSet[Any]() + Array(1.toByte, 2.toByte) + Array(3.toByte)
val nS = HashSet[Any]() + Array(1.toByte, 2.toByte) + Array(3.toByte) + null
val onetwo = Literal(Array(1.toByte, 2.toByte))
val three = Literal(Array(3.toByte))
val threefour = Literal(Array(3.toByte, 4.toByte))
val nl = Literal(null, onetwo.dataType)
checkEvaluation(InSet(onetwo, hS), true)
checkEvaluation(InSet(three, hS), true)
checkEvaluation(InSet(three, nS), true)
checkEvaluation(InSet(threefour, hS), false)
checkEvaluation(InSet(threefour, nS), null)
checkEvaluation(InSet(nl, hS), null)
checkEvaluation(InSet(nl, nS), null)
}

test("INSET: struct") {
val hS = HashSet[Any]() + Literal.create((1, "a")).value + Literal.create((2, "b")).value
val nS = HashSet[Any]() + Literal.create((1, "a")).value + Literal.create((2, "b")).value + null
val oneA = Literal.create((1, "a"))
val twoB = Literal.create((2, "b"))
val twoC = Literal.create((2, "c"))
val nl = Literal(null, oneA.dataType)
checkEvaluation(InSet(oneA, hS), true)
checkEvaluation(InSet(twoB, hS), true)
checkEvaluation(InSet(twoB, nS), true)
checkEvaluation(InSet(twoC, hS), false)
checkEvaluation(InSet(twoC, nS), null)
checkEvaluation(InSet(nl, hS), null)
checkEvaluation(InSet(nl, nS), null)
}

test("INSET: array") {
val hS = HashSet[Any]() + Literal.create(Seq(1, 2)).value + Literal.create(Seq(3)).value
val nS = HashSet[Any]() + Literal.create(Seq(1, 2)).value + Literal.create(Seq(3)).value + null
val onetwo = Literal.create(Seq(1, 2))
val three = Literal.create(Seq(3))
val threefour = Literal.create(Seq(3, 4))
val nl = Literal(null, onetwo.dataType)
checkEvaluation(InSet(onetwo, hS), true)
checkEvaluation(InSet(three, hS), true)
checkEvaluation(InSet(three, nS), true)
checkEvaluation(InSet(threefour, hS), false)
checkEvaluation(InSet(threefour, nS), null)
checkEvaluation(InSet(nl, hS), null)
checkEvaluation(InSet(nl, nS), null)
}

private case class MyStruct(a: Long, b: String)
private case class MyStruct2(a: MyStruct, b: Array[Int])
private val udt = new ExamplePointUDT
Expand Down

0 comments on commit b9b68a6

Please sign in to comment.