Skip to content

Commit

Permalink
Fix InSet for binary, and struct and array with null.
Browse files Browse the repository at this point in the history
  • Loading branch information
ueshin committed Nov 29, 2018
1 parent 8bfea86 commit 277c48f
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -367,11 +367,29 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with
}

@transient lazy val set: Set[Any] = child.dataType match {
case _: AtomicType => hset
case t: AtomicType if !t.isInstanceOf[BinaryType] => hset
case _: NullType => hset
case _ =>
val ord = TypeUtils.getInterpretedOrdering(child.dataType)
val ordering = if (hasNull) {
new Ordering[Any] {
override def compare(x: Any, y: Any): Int = {
if (x == null && y == null) {
0
} else if (x == null) {
-1
} else if (y == null) {
1
} else {
ord.compare(x, y)
}
}
}
} else {
ord
}
// for structs use interpreted ordering to be able to compare UnsafeRows with non-UnsafeRows
TreeSet.empty(TypeUtils.getInterpretedOrdering(child.dataType)) ++ hset
TreeSet.empty(ordering) ++ hset
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(InSet(nl, nS), null)

val primitiveTypes = Seq(IntegerType, FloatType, DoubleType, StringType, ByteType, ShortType,
LongType, BinaryType, BooleanType, DecimalType.USER_DEFAULT, TimestampType)
LongType, BooleanType, DecimalType.USER_DEFAULT, TimestampType)
primitiveTypes.foreach { t =>
val dataGen = RandomDataGenerator.forType(t, nullable = true).get
val inputData = Seq.fill(10) {
Expand All @@ -293,6 +293,54 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
}
}

test("INSET: binary") {
val hS = HashSet[Any]() + Array(1.toByte, 2.toByte) + Array(3.toByte)
val nS = HashSet[Any]() + Array(1.toByte, 2.toByte) + Array(3.toByte) + null
val onetwo = Literal(Array(1.toByte, 2.toByte))
val three = Literal(Array(3.toByte))
val threefour = Literal(Array(3.toByte, 4.toByte))
val nl = Literal(null, onetwo.dataType)
checkEvaluation(InSet(onetwo, hS), true)
checkEvaluation(InSet(three, hS), true)
checkEvaluation(InSet(three, nS), true)
checkEvaluation(InSet(threefour, hS), false)
checkEvaluation(InSet(threefour, nS), null)
checkEvaluation(InSet(nl, hS), null)
checkEvaluation(InSet(nl, nS), null)
}

test("INSET: struct") {
val hS = HashSet[Any]() + Literal.create((1, "a")).value + Literal.create((2, "b")).value
val nS = HashSet[Any]() + Literal.create((1, "a")).value + Literal.create((2, "b")).value + null
val oneA = Literal.create((1, "a"))
val twoB = Literal.create((2, "b"))
val twoC = Literal.create((2, "c"))
val nl = Literal(null, oneA.dataType)
checkEvaluation(InSet(oneA, hS), true)
checkEvaluation(InSet(twoB, hS), true)
checkEvaluation(InSet(twoB, nS), true)
checkEvaluation(InSet(twoC, hS), false)
checkEvaluation(InSet(twoC, nS), null)
checkEvaluation(InSet(nl, hS), null)
checkEvaluation(InSet(nl, nS), null)
}

test("INSET: array") {
val hS = HashSet[Any]() + Literal.create(Seq(1, 2)).value + Literal.create(Seq(3)).value
val nS = HashSet[Any]() + Literal.create(Seq(1, 2)).value + Literal.create(Seq(3)).value + null
val onetwo = Literal.create(Seq(1, 2))
val three = Literal.create(Seq(3))
val threefour = Literal.create(Seq(3, 4))
val nl = Literal(null, onetwo.dataType)
checkEvaluation(InSet(onetwo, hS), true)
checkEvaluation(InSet(three, hS), true)
checkEvaluation(InSet(three, nS), true)
checkEvaluation(InSet(threefour, hS), false)
checkEvaluation(InSet(threefour, nS), null)
checkEvaluation(InSet(nl, hS), null)
checkEvaluation(InSet(nl, nS), null)
}

private case class MyStruct(a: Long, b: String)
private case class MyStruct2(a: MyStruct, b: Array[Int])
private val udt = new ExamplePointUDT
Expand Down

0 comments on commit 277c48f

Please sign in to comment.