From 55d0233d19cc52bee91a9619057d9b6f33165a0a Mon Sep 17 00:00:00 2001
From: Ruifeng Zheng
Date: Tue, 24 Sep 2024 07:48:23 -0700
Subject: [PATCH] [SPARK-49713][PYTHON][FOLLOWUP] Make function `count_min_sketch` accept long seed

### What changes were proposed in this pull request?
Make the function `count_min_sketch` accept a long seed.

### Why are the changes needed?
The existing implementation only accepts an int seed, which is inconsistent with other `ExpressionWithRandomSeed` expressions:

```py
In [3]: >>> from pyspark.sql import functions as sf
   ...: >>> spark.range(100).select(
   ...: ...     sf.hex(sf.count_min_sketch("id", sf.lit(1.5), 0.6, 1111111111111111111))
   ...: ... ).show(truncate=False)
...
AnalysisException: [DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE] Cannot resolve "count_min_sketch(id, 1.5, 0.6, 1111111111111111111)" due to data type mismatch: The 4th parameter requires the "INT" type, however "1111111111111111111" has the type "BIGINT". SQLSTATE: 42K09;
'Aggregate [unresolvedalias('hex(count_min_sketch(id#64L, 1.5, 0.6, 1111111111111111111, 0, 0)))]
+- Range (0, 100, step=1, splits=Some(12))
...
```

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Added a doctest.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #48223 from zhengruifeng/count_min_sk_long_seed.

Authored-by: Ruifeng Zheng
Signed-off-by: Dongjoon Hyun
---
 python/pyspark/sql/connect/functions/builtin.py   |  3 +--
 python/pyspark/sql/functions/builtin.py           | 14 +++++++++++++-
 .../scala/org/apache/spark/sql/functions.scala    |  2 +-
 .../expressions/aggregate/CountMinSketchAgg.scala |  8 ++++++--
 4 files changed, 21 insertions(+), 6 deletions(-)

diff --git a/python/pyspark/sql/connect/functions/builtin.py b/python/pyspark/sql/connect/functions/builtin.py
index 2a39bc6bfddda..6953230f5b42e 100644
--- a/python/pyspark/sql/connect/functions/builtin.py
+++ b/python/pyspark/sql/connect/functions/builtin.py
@@ -70,7 +70,6 @@
     StringType,
 )
 from pyspark.sql.utils import enum_to_value as _enum_to_value
-from pyspark.util import JVM_INT_MAX
 
 # The implementation of pandas_udf is embedded in pyspark.sql.function.pandas_udf
 # for code reuse.
@@ -1130,7 +1129,7 @@ def count_min_sketch(
     confidence: Union[Column, float],
     seed: Optional[Union[Column, int]] = None,
 ) -> Column:
-    _seed = lit(random.randint(0, JVM_INT_MAX)) if seed is None else lit(seed)
+    _seed = lit(random.randint(0, sys.maxsize)) if seed is None else lit(seed)
     return _invoke_function_over_columns("count_min_sketch", col, lit(eps), lit(confidence), _seed)

diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py
index 2688f9daa23a4..09a286fe7c94e 100644
--- a/python/pyspark/sql/functions/builtin.py
+++ b/python/pyspark/sql/functions/builtin.py
@@ -6080,7 +6080,19 @@ def count_min_sketch(
     |0000000100000000000000640000000100000002000000005D96391C00000000000000320000000000000032|
     +----------------------------------------------------------------------------------------+
 
-    Example 3: Using a random seed
+    Example 3: Using a long seed
+
+    >>> from pyspark.sql import functions as sf
+    >>> spark.range(100).select(
+    ...     sf.hex(sf.count_min_sketch("id", sf.lit(1.5), 0.2, 1111111111111111111))
+    ... ).show(truncate=False)
+    +----------------------------------------------------------------------------------------+
+    |hex(count_min_sketch(id, 1.5, 0.2, 1111111111111111111))                                |
+    +----------------------------------------------------------------------------------------+
+    |00000001000000000000006400000001000000020000000044078BA100000000000000320000000000000032|
+    +----------------------------------------------------------------------------------------+
+
+    Example 4: Using a random seed
 
     >>> from pyspark.sql import functions as sf
     >>> spark.range(100).select(

diff --git a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala
index d9bceabe88f8f..ab69789c75f50 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala
@@ -399,7 +399,7 @@ object functions {
    * @since 4.0.0
    */
   def count_min_sketch(e: Column, eps: Column, confidence: Column): Column =
-    count_min_sketch(e, eps, confidence, lit(SparkClassUtils.random.nextInt))
+    count_min_sketch(e, eps, confidence, lit(SparkClassUtils.random.nextLong))
 
   private[spark] def collect_top_k(e: Column, num: Int, reverse: Boolean): Column =
     Column.internalFn("collect_top_k", e, lit(num), lit(reverse))

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala
index c26c4a9bdfea3..f0a27677628dc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala
@@ -63,7 +63,10 @@ case class CountMinSketchAgg(
   // Mark as lazy so that they are not evaluated during tree transformation.
   private lazy val eps: Double = epsExpression.eval().asInstanceOf[Double]
   private lazy val confidence: Double = confidenceExpression.eval().asInstanceOf[Double]
-  private lazy val seed: Int = seedExpression.eval().asInstanceOf[Int]
+  private lazy val seed: Int = seedExpression.eval() match {
+    case i: Int => i
+    case l: Long => l.toInt
+  }
 
   override def checkInputDataTypes(): TypeCheckResult = {
     val defaultCheck = super.checkInputDataTypes()
@@ -168,7 +171,8 @@ case class CountMinSketchAgg(
     copy(inputAggBufferOffset = newInputAggBufferOffset)
 
   override def inputTypes: Seq[AbstractDataType] = {
-    Seq(TypeCollection(IntegralType, StringType, BinaryType), DoubleType, DoubleType, IntegerType)
+    Seq(TypeCollection(IntegralType, StringType, BinaryType), DoubleType, DoubleType,
+      TypeCollection(IntegerType, LongType))
   }
 
   override def nullable: Boolean = false
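
A note on the seed narrowing in the `CountMinSketchAgg` change: the aggregate keeps its internal `seed` as an `Int`, so a `BIGINT` seed passes the type check and is then narrowed with `Long.toInt`, which keeps only the low 32 bits of the value. The following is a minimal standalone sketch of that behavior, not Spark code; `SeedTruncationDemo` and `toIntSeed` are illustrative names invented here, and the only logic borrowed from the patch is the two-case pattern match.

```scala
object SeedTruncationDemo {
  // Mirrors the match added to CountMinSketchAgg.seed. In Spark, the
  // inputTypes check (TypeCollection(IntegerType, LongType)) guarantees the
  // evaluated seed is either an Int or a Long, so no other cases are needed.
  def toIntSeed(seed: Any): Int = seed match {
    case i: Int  => i
    case l: Long => l.toInt // narrowing conversion: keeps the low 32 bits
  }

  def main(args: Array[String]): Unit = {
    println(toIntSeed(42))                   // 42
    println(toIntSeed(1111111111111111111L)) // 734294471 = 1111111111111111111 mod 2^32
  }
}
```

One consequence of this design, visible in the sketch above, is that two long seeds agreeing in their low 32 bits yield the same internal `Int` seed and therefore the same count-min sketch.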