Skip to content

Commit

Permalink
[SPARK-49713][PYTHON][FOLLOWUP] Make function count_min_sketch accept long seed
Browse files Browse the repository at this point in the history

### What changes were proposed in this pull request?
Make function `count_min_sketch` accept long seed

### Why are the changes needed?
The existing implementation only accepts an int seed, which is inconsistent with other `ExpressionWithRandomSeed` implementations:

```py
In [3]:     >>> from pyspark.sql import functions as sf
   ...:     >>> spark.range(100).select(
   ...:     ...     sf.hex(sf.count_min_sketch("id", sf.lit(1.5), 0.6, 1111111111111111111))
   ...:     ... ).show(truncate=False)

...
AnalysisException: [DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE] Cannot resolve "count_min_sketch(id, 1.5, 0.6, 1111111111111111111)" due to data type mismatch: The 4th parameter requires the "INT" type, however "1111111111111111111" has the type "BIGINT". SQLSTATE: 42K09;
'Aggregate [unresolvedalias('hex(count_min_sketch(id#64L, 1.5, 0.6, 1111111111111111111, 0, 0)))]
+- Range (0, 100, step=1, splits=Some(12))
...

```

### Does this PR introduce _any_ user-facing change?
no

### How was this patch tested?
added doctest

### Was this patch authored or co-authored using generative AI tooling?
no

Closes apache#48223 from zhengruifeng/count_min_sk_long_seed.

Authored-by: Ruifeng Zheng <ruifengz@apache.org>
Signed-off-by: Dongjoon Hyun <dongjoon@apache.org>
  • Loading branch information
zhengruifeng authored and dongjoon-hyun committed Sep 24, 2024
1 parent dedf5aa commit 55d0233
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 6 deletions.
3 changes: 1 addition & 2 deletions python/pyspark/sql/connect/functions/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@
StringType,
)
from pyspark.sql.utils import enum_to_value as _enum_to_value
from pyspark.util import JVM_INT_MAX

# The implementation of pandas_udf is embedded in pyspark.sql.function.pandas_udf
# for code reuse.
Expand Down Expand Up @@ -1130,7 +1129,7 @@ def count_min_sketch(
confidence: Union[Column, float],
seed: Optional[Union[Column, int]] = None,
) -> Column:
_seed = lit(random.randint(0, JVM_INT_MAX)) if seed is None else lit(seed)
_seed = lit(random.randint(0, sys.maxsize)) if seed is None else lit(seed)
return _invoke_function_over_columns("count_min_sketch", col, lit(eps), lit(confidence), _seed)


Expand Down
14 changes: 13 additions & 1 deletion python/pyspark/sql/functions/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -6080,7 +6080,19 @@ def count_min_sketch(
|0000000100000000000000640000000100000002000000005D96391C00000000000000320000000000000032|
+----------------------------------------------------------------------------------------+

Example 3: Using a random seed
Example 3: Using a long seed

>>> from pyspark.sql import functions as sf
>>> spark.range(100).select(
... sf.hex(sf.count_min_sketch("id", sf.lit(1.5), 0.2, 1111111111111111111))
... ).show(truncate=False)
+----------------------------------------------------------------------------------------+
|hex(count_min_sketch(id, 1.5, 0.2, 1111111111111111111)) |
+----------------------------------------------------------------------------------------+
|00000001000000000000006400000001000000020000000044078BA100000000000000320000000000000032|
+----------------------------------------------------------------------------------------+

Example 4: Using a random seed

>>> from pyspark.sql import functions as sf
>>> spark.range(100).select(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ object functions {
* @since 4.0.0
*/
def count_min_sketch(e: Column, eps: Column, confidence: Column): Column =
count_min_sketch(e, eps, confidence, lit(SparkClassUtils.random.nextInt))
count_min_sketch(e, eps, confidence, lit(SparkClassUtils.random.nextLong))

private[spark] def collect_top_k(e: Column, num: Int, reverse: Boolean): Column =
Column.internalFn("collect_top_k", e, lit(num), lit(reverse))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,10 @@ case class CountMinSketchAgg(
// Mark as lazy so that they are not evaluated during tree transformation.
private lazy val eps: Double = epsExpression.eval().asInstanceOf[Double]
private lazy val confidence: Double = confidenceExpression.eval().asInstanceOf[Double]
private lazy val seed: Int = seedExpression.eval().asInstanceOf[Int]
private lazy val seed: Int = seedExpression.eval() match {
case i: Int => i
case l: Long => l.toInt
}

override def checkInputDataTypes(): TypeCheckResult = {
val defaultCheck = super.checkInputDataTypes()
Expand Down Expand Up @@ -168,7 +171,8 @@ case class CountMinSketchAgg(
copy(inputAggBufferOffset = newInputAggBufferOffset)

override def inputTypes: Seq[AbstractDataType] = {
Seq(TypeCollection(IntegralType, StringType, BinaryType), DoubleType, DoubleType, IntegerType)
Seq(TypeCollection(IntegralType, StringType, BinaryType), DoubleType, DoubleType,
TypeCollection(IntegerType, LongType))
}

override def nullable: Boolean = false
Expand Down

0 comments on commit 55d0233

Please sign in to comment.