[SPARK-24659][SQL] GenericArrayData.equals should respect element type differences

## What changes were proposed in this pull request?

Fix `GenericArrayData.equals`, so that it respects the actual types of the elements.
For example, an instance that represents an `array<int>` and another that represents an `array<long>` should be considered incompatible, so `equals` should return false for them.

`GenericArrayData` doesn't keep any schema information itself; instead, it relies on the Java objects held in its `array` field to carry their own types. So the most straightforward way to respect element types is to call `equals` on the elements instead of using Scala's `==` operator, which treats boxed numbers of different types as equal when their values match:
```
new java.lang.Integer(123) == new java.lang.Long(123L) // true in Scala
new java.lang.Integer(123).equals(new java.lang.Long(123L)) // false in Scala
```
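The same difference shows up when the values are held as `Any`, which is how elements sit in an `Array[Any]`. The following is a minimal sketch, not code from the patch; `BoxedEqualityDemo`, `looseEquals`, and `strictEquals` are hypothetical names introduced only for illustration.

```scala
object BoxedEqualityDemo {
  // Hypothetical helper: element comparison via Scala's `==`, as in the pre-fix code path.
  def looseEquals(o1: Any, o2: Any): Boolean = o1 == o2

  // Hypothetical helper: element comparison via `equals`, as in the patched code path.
  // Assumes o1 is non-null; calling `equals` on null would throw.
  def strictEquals(o1: Any, o2: Any): Boolean = o1.equals(o2)

  def main(args: Array[String]): Unit = {
    val intElem: Any = 123    // boxed as java.lang.Integer
    val longElem: Any = 123L  // boxed as java.lang.Long

    println(looseEquals(intElem, longElem))   // true: `==` compares boxed numbers by value
    println(strictEquals(intElem, longElem))  // false: Integer.equals rejects a Long
    println(strictEquals(intElem, 123))       // true: same type, same value
  }
}
```

Because `equals` is dispatched on the runtime class of the left operand, the element itself decides whether the other value is type-compatible, which is what a schema-less container needs.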

## How was this patch tested?

Added a unit test in `ComplexDataSuite`.
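The property the test checks can be sketched without Spark on the classpath. The class below is a hypothetical stand-in, not Spark's `GenericArrayData`: it only assumes an `Array[Any]` of boxed elements and compares them with `equals`. It omits any special handling for binary elements, since `Array[Byte]` inherits reference equality and would need a content comparison.

```scala
// Hypothetical stand-in for illustration only; this is not Spark's GenericArrayData.
final class SimpleArrayData(val array: Array[Any]) {
  override def equals(o: Any): Boolean = o match {
    case other: SimpleArrayData if other.array.length == array.length =>
      array.indices.forall { i =>
        val o1 = array(i)
        val o2 = other.array(i)
        if (o1 == null || o2 == null) o1 == null && o2 == null
        else o1.equals(o2) // type-sensitive: Integer(123) is not equal to Long(123)
      }
    case _ => false
  }

  // Equal instances always have equal lengths, so this is consistent with equals.
  override def hashCode(): Int = array.length
}

object SimpleArrayDataDemo {
  def main(args: Array[String]): Unit = {
    val ints = new SimpleArrayData(Array[Any](0, 123, Int.MaxValue))
    val sameInts = new SimpleArrayData(Array[Any](0, 123, Int.MaxValue))
    val longs = new SimpleArrayData(Array[Any](0L, 123L, Int.MaxValue.toLong))

    assert(ints == sameInts)  // same element types and values
    assert(ints != longs)     // same values, different element types
  }
}
```

In this sketch `NaN` elements compare equal to each other, because `java.lang.Double.equals` and `java.lang.Float.equals` compare bit patterns rather than using IEEE `==`.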

Author: Kris Mok <kris.mok@databricks.com>

Closes #21643 from rednaxelafx/fix-genericarraydata-equals.
rednaxelafx authored and cloud-fan committed Jun 27, 2018
1 parent 16f2c3e commit 1b9368f
Showing 2 changed files with 37 additions and 1 deletion.
@@ -122,7 +122,7 @@ class GenericArrayData(val array: Array[Any]) extends ArrayData {
             if (!o2.isInstanceOf[Double] || ! java.lang.Double.isNaN(o2.asInstanceOf[Double])) {
               return false
             }
-          case _ => if (o1 != o2) {
+          case _ => if (!o1.equals(o2)) {
             return false
           }
         }
@@ -104,4 +104,40 @@ class ComplexDataSuite extends SparkFunSuite {
     // The copied data should not be changed externally.
     assert(copied.getStruct(0, 1).getUTF8String(0).toString == "a")
   }
+
+  test("SPARK-24659: GenericArrayData.equals should respect element type differences") {
+    import scala.reflect.ClassTag
+
+    // Expected positive cases
+    def arraysShouldEqual[T: ClassTag](element: T*): Unit = {
+      val array1 = new GenericArrayData(Array[T](element: _*))
+      val array2 = new GenericArrayData(Array[T](element: _*))
+      assert(array1.equals(array2))
+    }
+    arraysShouldEqual(true, false) // Boolean
+    arraysShouldEqual(0.toByte, 123.toByte, Byte.MinValue, Byte.MaxValue) // Byte
+    arraysShouldEqual(0.toShort, 123.toShort, Short.MinValue, Short.MaxValue) // Short
+    arraysShouldEqual(0, 123, -65536, Int.MinValue, Int.MaxValue) // Int
+    arraysShouldEqual(0L, 123L, -65536L, Long.MinValue, Long.MaxValue) // Long
+    arraysShouldEqual(0.0F, 123.0F, Float.MinValue, Float.MaxValue, Float.MinPositiveValue,
+      Float.PositiveInfinity, Float.NegativeInfinity, Float.NaN) // Float
+    arraysShouldEqual(0.0, 123.0, Double.MinValue, Double.MaxValue, Double.MinPositiveValue,
+      Double.PositiveInfinity, Double.NegativeInfinity, Double.NaN) // Double
+    arraysShouldEqual(Array[Byte](123.toByte), Array[Byte](), null) // SQL Binary
+    arraysShouldEqual(UTF8String.fromString("foo"), null) // SQL String
+
+    // Expected negative cases
+    // Spark SQL considers cases like array<int> vs array<long> to be incompatible,
+    // so an underlying implementation of array type should return false in such cases.
+    def arraysShouldNotEqual[T: ClassTag, U: ClassTag](element1: T, element2: U): Unit = {
+      val array1 = new GenericArrayData(Array[T](element1))
+      val array2 = new GenericArrayData(Array[U](element2))
+      assert(!array1.equals(array2))
+    }
+    arraysShouldNotEqual(true, 1) // Boolean <-> Int
+    arraysShouldNotEqual(123.toByte, 123) // Byte <-> Int
+    arraysShouldNotEqual(123.toByte, 123L) // Byte <-> Long
+    arraysShouldNotEqual(123.toShort, 123) // Short <-> Int
+    arraysShouldNotEqual(123, 123L) // Int <-> Long
+  }
 }
