Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-24659][SQL] GenericArrayData.equals should respect element type differences #21643

Closed
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ class GenericArrayData(val array: Array[Any]) extends ArrayData {
if (!o2.isInstanceOf[Double] || ! java.lang.Double.isNaN(o2.asInstanceOf[Double])) {
return false
}
case _ => if (o1 != o2) {
case _ => if (!o1.equals(o2)) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are there any needs to handle Array[Byte] separately above?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in java byte[] or other primitive arrays doesn't have a proper equals implementation.

scala> Array(1) == Array(1)
res0: Boolean = false

return false
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,4 +104,38 @@ class ComplexDataSuite extends SparkFunSuite {
// The copied data should not be changed externally.
assert(copied.getStruct(0, 1).getUTF8String(0).toString == "a")
}

test("SPARK-24659: GenericArrayData.equals should respect element type differences") {
import scala.reflect.ClassTag
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: you can move this import to the head of this file.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your suggestion! I'm used to making one-off imports inside a function when an import is only used within that function, so that the scope is as narrow as possible without being disturbing.
Are there any Spark coding style guidelines that suggest otherwise? If so I'll follow the guideline and always import at the beginning of the file.


// Expected positive cases
def arraysShouldEqual[T: ClassTag](element: T*): Unit = {
val array1 = new GenericArrayData(Array[T](element: _*))
val array2 = new GenericArrayData(Array[T](element: _*))
assert(array1.equals(array2))
}
arraysShouldEqual(true, false) // Boolean
arraysShouldEqual(0.toByte, 123.toByte, (-123).toByte) // Byte
arraysShouldEqual(0.toShort, 123.toShort, (-256).toShort) // Short
arraysShouldEqual(0, 123, -65536) // Int
arraysShouldEqual(0L, 123L, -65536L) // Long
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is not important but if you are checking corner cases, probably, it makes sense to pass values like Long.MinValue and Double.MaxValue

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a good one. I can do that (and NaNs/Infinity for floating point types too)

arraysShouldEqual(0.0F, 123.0F, -65536.0F) // Float
arraysShouldEqual(0.0, 123.0, -65536.0) // Double
arraysShouldEqual(Array[Byte](123.toByte), null) // Binary (Array[Byte])
arraysShouldEqual(UTF8String.fromString("foo"), null) // String (UTF8String)

// Expected negative cases
// Spark SQL considers cases like array<int> vs array<long> to be incompatible,
// so an underlying implementation of array type should return false in such cases.
def arraysShouldNotEqual[T: ClassTag, U: ClassTag](element1: T, element2: U): Unit = {
val array1 = new GenericArrayData(Array[T](element1))
val array2 = new GenericArrayData(Array[U](element2))
assert(!array1.equals(array2))
}
arraysShouldNotEqual(true, 1) // Boolean <-> Int
arraysShouldNotEqual(123.toByte, 123) // Byte <-> Int
arraysShouldNotEqual(123.toByte, 123L) // Byte <-> Long
arraysShouldNotEqual(123.toShort, 123) // Short <-> Int
arraysShouldNotEqual(123, 123L) // Int <-> Long
}
}