Skip to content

Commit

Permalink
[SPARK-45564][SQL] Simplify 'DataFrameStatFunctions.bloomFilter' with…
Browse files Browse the repository at this point in the history
… 'BloomFilterAggregate' expression

### What changes were proposed in this pull request?
Simplify 'DataFrameStatFunctions.bloomFilter' function with 'BloomFilterAggregate' expression

### Why are the changes needed?
The existing implementation was based on RDD operations; it can be simplified by using DataFrame operations instead.

### Does this PR introduce _any_ user-facing change?
Yes: when the input parameters or data types are invalid, an `AnalysisException` is now thrown instead of an `IllegalArgumentException`.

### How was this patch tested?
Existing CI tests.

### Was this patch authored or co-authored using generative AI tooling?
no

Closes apache#43391 from zhengruifeng/sql_reimpl_stat_bloomFilter.

Authored-by: Ruifeng Zheng <ruifengz@apache.org>
Signed-off-by: Sean Owen <srowen@gmail.com>
  • Loading branch information
zhengruifeng authored and srowen committed Oct 17, 2023
1 parent f00ec39 commit 922844f
Showing 1 changed file with 14 additions and 54 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ import scala.jdk.CollectionConverters._

import org.apache.spark.annotation.Stable
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate
import org.apache.spark.sql.execution.stat._
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -535,7 +537,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
* @since 2.0.0
*/
def bloomFilter(colName: String, expectedNumItems: Long, fpp: Double): BloomFilter = {
buildBloomFilter(Column(colName), expectedNumItems, -1L, fpp)
bloomFilter(Column(colName), expectedNumItems, fpp)
}

/**
Expand All @@ -547,7 +549,8 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
* @since 2.0.0
*/
def bloomFilter(col: Column, expectedNumItems: Long, fpp: Double): BloomFilter = {
buildBloomFilter(col, expectedNumItems, -1L, fpp)
val numBits = BloomFilter.optimalNumOfBits(expectedNumItems, fpp)
bloomFilter(col, expectedNumItems, numBits)
}

/**
Expand All @@ -559,7 +562,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
* @since 2.0.0
*/
def bloomFilter(colName: String, expectedNumItems: Long, numBits: Long): BloomFilter = {
buildBloomFilter(Column(colName), expectedNumItems, numBits, Double.NaN)
bloomFilter(Column(colName), expectedNumItems, numBits)
}

/**
Expand All @@ -571,57 +574,14 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
* @since 2.0.0
*/
def bloomFilter(col: Column, expectedNumItems: Long, numBits: Long): BloomFilter = {
buildBloomFilter(col, expectedNumItems, numBits, Double.NaN)
}

private def buildBloomFilter(col: Column, expectedNumItems: Long,
numBits: Long,
fpp: Double): BloomFilter = {
val singleCol = df.select(col)
val colType = singleCol.schema.head.dataType

require(colType == StringType || colType.isInstanceOf[IntegralType],
s"Bloom filter only supports string type and integral types, but got $colType.")

val updater: (BloomFilter, InternalRow) => Unit = colType match {
// For string type, we can get bytes of our `UTF8String` directly, and call the `putBinary`
// instead of `putString` to avoid unnecessary conversion.
case StringType => (filter, row) => filter.putBinary(row.getUTF8String(0).getBytes)
case ByteType => (filter, row) => filter.putLong(row.getByte(0))
case ShortType => (filter, row) => filter.putLong(row.getShort(0))
case IntegerType => (filter, row) => filter.putLong(row.getInt(0))
case LongType => (filter, row) => filter.putLong(row.getLong(0))
case _ =>
throw new IllegalArgumentException(
s"Bloom filter only supports string type and integral types, " +
s"and does not support type $colType."
)
}

singleCol.queryExecution.toRdd.treeAggregate(null.asInstanceOf[BloomFilter])(
(filter: BloomFilter, row: InternalRow) => {
val theFilter =
if (filter == null) {
if (fpp.isNaN) {
BloomFilter.create(expectedNumItems, numBits)
} else {
BloomFilter.create(expectedNumItems, fpp)
}
} else {
filter
}
updater(theFilter, row)
theFilter
},
(filter1, filter2) => {
if (filter1 == null) {
filter2
} else if (filter2 == null) {
filter1
} else {
filter1.mergeInPlace(filter2)
}
}
val bloomFilterAgg = new BloomFilterAggregate(
col.expr,
Literal(expectedNumItems, LongType),
Literal(numBits, LongType)
)
val bytes = df.select(
Column(bloomFilterAgg.toAggregateExpression(false))
).head().getAs[Array[Byte]](0)
bloomFilterAgg.deserialize(bytes)
}
}

0 comments on commit 922844f

Please sign in to comment.