Skip to content

Commit

Permalink
[SPARK-44154][SQL] Added more unit tests to BitmapExpressionUtilsSuit…
Browse files Browse the repository at this point in the history
…e and made minor improvements to Bitmap Aggregate Expressions

### What changes were proposed in this pull request?
I firstly added more unit tests for the `BITMAT_BIT_POSITION` and `BITMAP_BUCKET_NUMBER` expressions. Secondly, I made a minor improvement in the implementation of the `BITMAP_CONSTRUCT_AGG` and `BUTMAP_OR_AGG` expressions, where I converted `inputAggBufferAttributes` from a method to a value.

### Why are the changes needed?
The unit tests cover more corner cases. Having `inputAggBufferAttributes` as a value makes it so that the AttributeReferences aren't reinitialized every time `inputAggBufferAttributes` is referred to.

### How was this patch tested?
I reran all the tests for Bitmap expressions and they succeeded. The test suites were `BitmapExpressionUtilsSuite` and `BitmapExpressionsQuerySuite`.

Closes apache#42043 from harshmotw-db/harsh-dev.

Authored-by: Harsh Motwani <harsh.motwani@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information
harshmotw-db authored and ragnarok56 committed Mar 2, 2024
1 parent 840dae0 commit 6058eba
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -177,17 +177,17 @@ case class BitmapConstructAgg(child: Expression,

override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)

// The aggregation buffer is a fixed size binary.
private val bitmapAttr = AttributeReference("bitmap", BinaryType, nullable = false)()

override def aggBufferAttributes: Seq[AttributeReference] = bitmapAttr :: Nil

override def defaultResult: Option[Literal] =
Option(Literal(Array.fill[Byte](BitmapExpressionUtils.NUM_BYTES)(0)))

override def inputAggBufferAttributes: Seq[AttributeReference] =
override val inputAggBufferAttributes: Seq[AttributeReference] =
aggBufferAttributes.map(_.newInstance())

// The aggregation buffer is a fixed size binary.
private val bitmapAttr = AttributeReference("bitmap", BinaryType, nullable = false)()

override def initialize(buffer: InternalRow): Unit = {
buffer.update(mutableAggBufferOffset, Array.fill[Byte](BitmapExpressionUtils.NUM_BYTES)(0))
}
Expand Down Expand Up @@ -270,17 +270,17 @@ case class BitmapOrAgg(child: Expression,

override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)

// The aggregation buffer is a fixed size binary.
private val bitmapAttr = AttributeReference("bitmap", BinaryType, false)()

override def aggBufferAttributes: Seq[AttributeReference] = bitmapAttr :: Nil

override def defaultResult: Option[Literal] =
Option(Literal(Array.fill[Byte](BitmapExpressionUtils.NUM_BYTES)(0)))

override def inputAggBufferAttributes: Seq[AttributeReference] =
override val inputAggBufferAttributes: Seq[AttributeReference] =
aggBufferAttributes.map(_.newInstance())

// The aggregation buffer is a fixed size binary.
private val bitmapAttr = AttributeReference("bitmap", BinaryType, false)()

override def initialize(buffer: InternalRow): Unit = {
buffer.update(mutableAggBufferOffset, Array.fill[Byte](BitmapExpressionUtils.NUM_BYTES)(0))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,31 +22,36 @@ import org.apache.spark.SparkFunSuite
class BitmapExpressionUtilsSuite extends SparkFunSuite {

test("bitmap_bucket_number with positive inputs") {
Seq((0L, 0L), (1L, 1L), (2L, 1L), (3L, 1L),
(32768L, 1L), (32769L, 2L), (32770L, 2L)).foreach {
Seq((0L, 0L), (1L, 1L), (2L, 1L), (3L, 1L), (65537L, 3L), (65536L, 2L), (3232423L, 99L),
(4538345L, 139L), (845894934L, 25815L), (2147483647L, 65536L),
(Long.MaxValue, 281474976710656L), (32768L, 1L), (32769L, 2L), (32770L, 2L)).foreach {
case (input, expected) =>
assert(BitmapExpressionUtils.bitmapBucketNumber(input) == expected)
}
}

test("bitmap_bucket_number with negative inputs") {
Seq((-1L, 0L), (-2L, 0L), (-3L, 0L),
(-32767L, 0L), (-32768L, -1L), (-32769L, -1L)).foreach {
Seq((-1L, 0L), (-2L, 0L), (-3L, 0L), (-65536L, -2L), (65537L, 3L), (-65535L, -1L),
(-3843485L, -117L), (-2147483647L, -65535L), (-2147483648L, -65536L),
(Long.MinValue, -281474976710656L), (Long.MinValue + 1, -281474976710655L), (-32767L, 0L),
(-32768L, -1L), (-32769L, -1L)).foreach {
case (input, expected) =>
assert(BitmapExpressionUtils.bitmapBucketNumber(input) == expected)
}
}

test("bitmap_bit_position with positive inputs") {
Seq((0L, 0L), (1L, 0L), (2L, 1L), (3L, 2L),
Seq((0L, 0L), (1L, 0L), (2L, 1L), (3L, 2L), (65537L, 0L), (65536L, 32767L), (3232423L, 21158L),
(4538345L, 16360L), (845894934L, 21781L), (2147483647L, 32766L), (Long.MaxValue, 32766L),
(32768L, 32767L), (32769L, 0L), (32770L, 1L)).foreach {
case (input, expected) =>
assert(BitmapExpressionUtils.bitmapBitPosition(input) == expected)
}
}

test("bitmap_bit_position with negative inputs") {
Seq((-1L, 1L), (-2L, 2L), (-3L, 3L),
Seq((-1L, 1L), (-2L, 2L), (-3L, 3L), (-65536L, 0L), (-65535L, 32767L), (-3843485L, 9629L),
(-2147483647L, 32767L), (-2147483648L, 0L), (Long.MinValue, 0L), (Long.MinValue + 1, 32767L),
(-32767L, 32767L), (-32768L, 0L), (-32769L, 1L)).foreach {
case (input, expected) =>
assert(BitmapExpressionUtils.bitmapBitPosition(input) == expected)
Expand Down

0 comments on commit 6058eba

Please sign in to comment.