diff --git a/python/pyspark/sql/pandas/functions.py b/python/pyspark/sql/pandas/functions.py
index ad9fdac970639..652129180df94 100644
--- a/python/pyspark/sql/pandas/functions.py
+++ b/python/pyspark/sql/pandas/functions.py
@@ -57,7 +57,7 @@ def pandas_udf(f=None, returnType=None, functionType=None):
         Supports Spark Connect.
 
     .. versionchanged:: 4.0.0
-        Supports keyword-arguments in SCALAR type.
+        Supports keyword-arguments in SCALAR and GROUPED_AGG types.
 
     Parameters
     ----------
@@ -267,6 +267,24 @@ def calculate(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
     |  2|        6.0|
     +---+-----------+
 
+    This type of Pandas UDF can use keyword arguments:
+
+    >>> @pandas_udf("double")
+    ... def weighted_mean_udf(v: pd.Series, w: pd.Series) -> float:
+    ...     import numpy as np
+    ...     return np.average(v, weights=w)
+    ...
+    >>> df = spark.createDataFrame(
+    ...     [(1, 1.0, 1.0), (1, 2.0, 2.0), (2, 3.0, 1.0), (2, 5.0, 2.0), (2, 10.0, 3.0)],
+    ...     ("id", "v", "w"))
+    >>> df.groupby("id").agg(weighted_mean_udf(w=df["w"], v=df["v"])).show()
+    +---+---------------------------------+
+    | id|weighted_mean_udf(w => w, v => v)|
+    +---+---------------------------------+
+    |  1|               1.6666666666666667|
+    |  2|                7.166666666666667|
+    +---+---------------------------------+
+
     This UDF can also be used as window functions as below:
 
     >>> from pyspark.sql import Window
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py b/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py
index f434489a6fb88..b500be7a96957 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py
@@ -32,7 +32,7 @@
     PandasUDFType,
 )
 from pyspark.sql.types import ArrayType, YearMonthIntervalType
-from pyspark.errors import AnalysisException, PySparkNotImplementedError
+from pyspark.errors import AnalysisException, PySparkNotImplementedError, PythonException
 from pyspark.testing.sqlutils import (
     ReusedSQLTestCase,
     have_pandas,
@@ -40,7 +40,7 @@
     pandas_requirement_message,
     pyarrow_requirement_message,
 )
-from pyspark.testing.utils import QuietTest
+from pyspark.testing.utils import QuietTest, assertDataFrameEqual
 
 
 if have_pandas:
@@ -575,6 +575,149 @@ def mean(x):
 
         assert filtered.collect()[0]["mean"] == 42.0
 
+    def test_named_arguments(self):
+        df = self.data
+        weighted_mean = self.pandas_agg_weighted_mean_udf
+
+        with self.tempView("v"):
+            df.createOrReplaceTempView("v")
+            self.spark.udf.register("weighted_mean", weighted_mean)
+
+            for i, aggregated in enumerate(
+                [
+                    df.groupby("id").agg(weighted_mean(df.v, w=df.w).alias("wm")),
+                    df.groupby("id").agg(weighted_mean(v=df.v, w=df.w).alias("wm")),
+                    df.groupby("id").agg(weighted_mean(w=df.w, v=df.v).alias("wm")),
+                    self.spark.sql("SELECT id, weighted_mean(v, w => w) as wm FROM v GROUP BY id"),
+                    self.spark.sql(
+                        "SELECT id, weighted_mean(v => v, w => w) as wm FROM v GROUP BY id"
+                    ),
+                    self.spark.sql(
+                        "SELECT id, weighted_mean(w => w, v => v) as wm FROM v GROUP BY id"
+                    ),
+                ]
+            ):
+                with self.subTest(query_no=i):
+                    assertDataFrameEqual(aggregated, df.groupby("id").agg(mean(df.v).alias("wm")))
+
+    def test_named_arguments_negative(self):
+        df = self.data
+        weighted_mean = self.pandas_agg_weighted_mean_udf
+
+        with self.tempView("v"):
+            df.createOrReplaceTempView("v")
+            self.spark.udf.register("weighted_mean", weighted_mean)
+
+            with self.assertRaisesRegex(
+                AnalysisException,
+                "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE",
+            ):
+                self.spark.sql(
+                    "SELECT id, weighted_mean(v => v, v => w) as wm FROM v GROUP BY id"
+                ).show()
+
+            with self.assertRaisesRegex(AnalysisException, "UNEXPECTED_POSITIONAL_ARGUMENT"):
+                self.spark.sql(
+                    "SELECT id, weighted_mean(v => v, w) as wm FROM v GROUP BY id"
+                ).show()
+
+            with self.assertRaisesRegex(
+                PythonException, r"weighted_mean\(\) got an unexpected keyword argument 'x'"
+            ):
+                self.spark.sql(
+                    "SELECT id, weighted_mean(v => v, x => w) as wm FROM v GROUP BY id"
+                ).show()
+
+            with self.assertRaisesRegex(
+                PythonException, r"weighted_mean\(\) got multiple values for argument 'v'"
+            ):
+                self.spark.sql(
+                    "SELECT id, weighted_mean(v, v => w) as wm FROM v GROUP BY id"
+                ).show()
+
+    def test_kwargs(self):
+        df = self.data
+
+        @pandas_udf("double", PandasUDFType.GROUPED_AGG)
+        def weighted_mean(**kwargs):
+            import numpy as np
+
+            return np.average(kwargs["v"], weights=kwargs["w"])
+
+        with self.tempView("v"):
+            df.createOrReplaceTempView("v")
+            self.spark.udf.register("weighted_mean", weighted_mean)
+
+            for i, aggregated in enumerate(
+                [
+                    df.groupby("id").agg(weighted_mean(v=df.v, w=df.w).alias("wm")),
+                    df.groupby("id").agg(weighted_mean(w=df.w, v=df.v).alias("wm")),
+                    self.spark.sql(
+                        "SELECT id, weighted_mean(v => v, w => w) as wm FROM v GROUP BY id"
+                    ),
+                    self.spark.sql(
+                        "SELECT id, weighted_mean(w => w, v => v) as wm FROM v GROUP BY id"
+                    ),
+                ]
+            ):
+                with self.subTest(query_no=i):
+                    assertDataFrameEqual(aggregated, df.groupby("id").agg(mean(df.v).alias("wm")))
+
+            # negative
+            with self.assertRaisesRegex(
+                AnalysisException,
+                "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE",
+            ):
+                self.spark.sql(
+                    "SELECT id, weighted_mean(v => v, v => w) as wm FROM v GROUP BY id"
+                ).show()
+
+            with self.assertRaisesRegex(AnalysisException, "UNEXPECTED_POSITIONAL_ARGUMENT"):
+                self.spark.sql(
+                    "SELECT id, weighted_mean(v => v, w) as wm FROM v GROUP BY id"
+                ).show()
+
+    def test_named_arguments_and_defaults(self):
+        df = self.data
+
+        @pandas_udf("double", PandasUDFType.GROUPED_AGG)
+        def biased_sum(v, w=None):
+            return v.sum() + (w.sum() if w is not None else 100)
+
+        with self.tempView("v"):
+            df.createOrReplaceTempView("v")
+            self.spark.udf.register("biased_sum", biased_sum)
+
+            # without "w"
+            for i, aggregated in enumerate(
+                [
+                    df.groupby("id").agg(biased_sum(df.v).alias("s")),
+                    df.groupby("id").agg(biased_sum(v=df.v).alias("s")),
+                    self.spark.sql("SELECT id, biased_sum(v) as s FROM v GROUP BY id"),
+                    self.spark.sql("SELECT id, biased_sum(v => v) as s FROM v GROUP BY id"),
+                ]
+            ):
+                with self.subTest(with_w=False, query_no=i):
+                    assertDataFrameEqual(
+                        aggregated, df.groupby("id").agg((sum(df.v) + lit(100)).alias("s"))
+                    )
+
+            # with "w"
+            for i, aggregated in enumerate(
+                [
+                    df.groupby("id").agg(biased_sum(df.v, w=df.w).alias("s")),
+                    df.groupby("id").agg(biased_sum(v=df.v, w=df.w).alias("s")),
+                    df.groupby("id").agg(biased_sum(w=df.w, v=df.v).alias("s")),
+                    self.spark.sql("SELECT id, biased_sum(v, w => w) as s FROM v GROUP BY id"),
+                    self.spark.sql("SELECT id, biased_sum(v => v, w => w) as s FROM v GROUP BY id"),
+                    self.spark.sql("SELECT id, biased_sum(w => w, v => v) as s FROM v GROUP BY id"),
+                ]
+            ):
+                with self.subTest(with_w=True, query_no=i):
+                    assertDataFrameEqual(
+                        aggregated, df.groupby("id").agg((sum(df.v) + sum(df.w)).alias("s"))
+                    )
+
 
 class GroupedAggPandasUDFTests(GroupedAggPandasUDFTestsMixin, ReusedSQLTestCase):
     pass
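For readers skimming the patch, here is a minimal, self-contained sketch of the behavior the grouped-agg tests above pin down. It is illustrative only and not part of the patch; it assumes a local SparkSession, and the names "weighted_mean" and the view "t" are arbitrary examples.

# Illustrative sketch, not part of the patch. Assumes a local SparkSession.
import numpy as np
import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf

spark = SparkSession.builder.getOrCreate()

@pandas_udf("double")
def weighted_mean(v: pd.Series, w: pd.Series) -> float:
    # Series -> scalar type hints make this a grouped-agg (GROUPED_AGG) UDF.
    return float(np.average(v, weights=w))

df = spark.createDataFrame(
    [(1, 1.0, 1.0), (1, 2.0, 2.0), (2, 3.0, 1.0)], ("id", "v", "w")
)

# Keyword arguments can be passed in any order from the DataFrame API...
df.groupby("id").agg(weighted_mean(w=df["w"], v=df["v"])).show()

# ...and from SQL, using the `=>` named-argument syntax.
spark.udf.register("weighted_mean", weighted_mean)
df.createOrReplaceTempView("t")
spark.sql("SELECT id, weighted_mean(w => w, v => v) AS wm FROM t GROUP BY id").show()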
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py b/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py
index e74e3783b1236..6968c0740943e 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py
@@ -18,7 +18,7 @@
 import unittest
 from typing import cast
 
-from pyspark.errors import AnalysisException
+from pyspark.errors import AnalysisException, PythonException
 from pyspark.sql.functions import (
     array,
     explode,
@@ -40,7 +40,7 @@
     pandas_requirement_message,
     pyarrow_requirement_message,
 )
-from pyspark.testing.utils import QuietTest
+from pyspark.testing.utils import QuietTest, assertDataFrameEqual
 
 if have_pandas:
     from pandas.testing import assert_frame_equal
@@ -107,6 +107,16 @@ def min(v):
 
         return min
 
+    @property
+    def pandas_agg_weighted_mean_udf(self):
+        import numpy as np
+
+        @pandas_udf("double", PandasUDFType.GROUPED_AGG)
+        def weighted_mean(v, w):
+            return np.average(v, weights=w)
+
+        return weighted_mean
+
     @property
     def unbounded_window(self):
         return (
@@ -394,6 +404,165 @@ def test_bounded_mixed(self):
 
         assert_frame_equal(expected1.toPandas(), result1.toPandas())
 
+    def test_named_arguments(self):
+        df = self.data
+        weighted_mean = self.pandas_agg_weighted_mean_udf
+
+        for w, bound in [(self.sliding_row_window, True), (self.unbounded_window, False)]:
+            for i, windowed in enumerate(
+                [
+                    df.withColumn("wm", weighted_mean(df.v, w=df.w).over(w)),
+                    df.withColumn("wm", weighted_mean(v=df.v, w=df.w).over(w)),
+                    df.withColumn("wm", weighted_mean(w=df.w, v=df.v).over(w)),
+                ]
+            ):
+                with self.subTest(bound=bound, query_no=i):
+                    assertDataFrameEqual(windowed, df.withColumn("wm", mean(df.v).over(w)))
+
+        with self.tempView("v"):
+            df.createOrReplaceTempView("v")
+            self.spark.udf.register("weighted_mean", weighted_mean)
+
+            for w in [
+                "ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING",
+                "ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING",
+            ]:
+                window_spec = f"PARTITION BY id ORDER BY v {w}"
+                for i, func_call in enumerate(
+                    [
+                        "weighted_mean(v, w => w)",
+                        "weighted_mean(v => v, w => w)",
+                        "weighted_mean(w => w, v => v)",
+                    ]
+                ):
+                    with self.subTest(window_spec=window_spec, query_no=i):
+                        assertDataFrameEqual(
+                            self.spark.sql(
+                                f"SELECT id, {func_call} OVER ({window_spec}) as wm FROM v"
+                            ),
+                            self.spark.sql(f"SELECT id, mean(v) OVER ({window_spec}) as wm FROM v"),
+                        )
+
+    def test_named_arguments_negative(self):
+        df = self.data
+        weighted_mean = self.pandas_agg_weighted_mean_udf
+
+        with self.tempView("v"):
+            df.createOrReplaceTempView("v")
+            self.spark.udf.register("weighted_mean", weighted_mean)
+
+            base_sql = "SELECT id, {func_call} OVER ({window_spec}) as wm FROM v"
+
+            for w in [
+                "ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING",
+                "ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING",
+            ]:
+                window_spec = f"PARTITION BY id ORDER BY v {w}"
+                with self.subTest(window_spec=window_spec):
+                    with self.assertRaisesRegex(
+                        AnalysisException,
+                        "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE",
+                    ):
+                        self.spark.sql(
+                            base_sql.format(
+                                func_call="weighted_mean(v => v, v => w)", window_spec=window_spec
+                            )
+                        ).show()
+
+                    with self.assertRaisesRegex(
+                        AnalysisException, "UNEXPECTED_POSITIONAL_ARGUMENT"
+                    ):
+                        self.spark.sql(
+                            base_sql.format(
+                                func_call="weighted_mean(v => v, w)", window_spec=window_spec
+                            )
+                        ).show()
+
+                    with self.assertRaisesRegex(
+                        PythonException, r"weighted_mean\(\) got an unexpected keyword argument 'x'"
+                    ):
+                        self.spark.sql(
+                            base_sql.format(
+                                func_call="weighted_mean(v => v, x => w)", window_spec=window_spec
+                            )
+                        ).show()
+
+                    with self.assertRaisesRegex(
+                        PythonException, r"weighted_mean\(\) got multiple values for argument 'v'"
+                    ):
+                        self.spark.sql(
+                            base_sql.format(
+                                func_call="weighted_mean(v, v => w)", window_spec=window_spec
+                            )
+                        ).show()
+
+    def test_kwargs(self):
+        df = self.data
+
+        @pandas_udf("double", PandasUDFType.GROUPED_AGG)
+        def weighted_mean(**kwargs):
+            import numpy as np
+
+            return np.average(kwargs["v"], weights=kwargs["w"])
+
+        for w, bound in [(self.sliding_row_window, True), (self.unbounded_window, False)]:
+            for i, windowed in enumerate(
+                [
+                    df.withColumn("wm", weighted_mean(v=df.v, w=df.w).over(w)),
+                    df.withColumn("wm", weighted_mean(w=df.w, v=df.v).over(w)),
+                ]
+            ):
+                with self.subTest(bound=bound, query_no=i):
+                    assertDataFrameEqual(windowed, df.withColumn("wm", mean(df.v).over(w)))
+
+        with self.tempView("v"):
+            df.createOrReplaceTempView("v")
+            self.spark.udf.register("weighted_mean", weighted_mean)
+
+            base_sql = "SELECT id, {func_call} OVER ({window_spec}) as wm FROM v"
+
+            for w in [
+                "ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING",
+                "ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING",
+            ]:
+                window_spec = f"PARTITION BY id ORDER BY v {w}"
+                with self.subTest(window_spec=window_spec):
+                    for i, func_call in enumerate(
+                        [
+                            "weighted_mean(v => v, w => w)",
+                            "weighted_mean(w => w, v => v)",
+                        ]
+                    ):
+                        with self.subTest(query_no=i):
+                            assertDataFrameEqual(
+                                self.spark.sql(
+                                    base_sql.format(func_call=func_call, window_spec=window_spec)
+                                ),
+                                self.spark.sql(
+                                    base_sql.format(func_call="mean(v)", window_spec=window_spec)
+                                ),
+                            )
+
+                    # negative
+                    with self.assertRaisesRegex(
+                        AnalysisException,
+                        "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE",
+                    ):
+                        self.spark.sql(
+                            base_sql.format(
+                                func_call="weighted_mean(v => v, v => w)", window_spec=window_spec
+                            )
+                        ).show()
+
+                    with self.assertRaisesRegex(
+                        AnalysisException, "UNEXPECTED_POSITIONAL_ARGUMENT"
+                    ):
+                        self.spark.sql(
+                            base_sql.format(
+                                func_call="weighted_mean(v => v, w)", window_spec=window_spec
+                            )
+                        ).show()
+
 
 class WindowPandasUDFTests(WindowPandasUDFTestsMixin, ReusedSQLTestCase):
     pass
diff --git a/python/pyspark/sql/tests/test_udf.py b/python/pyspark/sql/tests/test_udf.py
index f72bf28823006..32ea05bd00a7f 100644
--- a/python/pyspark/sql/tests/test_udf.py
+++ b/python/pyspark/sql/tests/test_udf.py
@@ -939,6 +939,11 @@ def test_udf(a, b):
         ):
             self.spark.sql("SELECT test_udf(c => 'x') FROM range(2)").show()
 
+        with self.assertRaisesRegex(
+            PythonException, r"test_udf\(\) got multiple values for argument 'a'"
+        ):
+            self.spark.sql("SELECT test_udf(id, a => id * 10) FROM range(2)").show()
+
     def test_kwargs(self):
         @udf("int")
         def test_udf(**kwargs):
@@ -957,6 +962,16 @@ def test_udf(**kwargs):
             with self.subTest(query_no=i):
                 assertDataFrameEqual(df, [Row(0), Row(101)])
 
+        # negative
+        with self.assertRaisesRegex(
+            AnalysisException,
+            "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE",
+        ):
+            self.spark.sql("SELECT test_udf(a => id, a => id * 10) FROM range(2)").show()
+
+        with self.assertRaisesRegex(AnalysisException, "UNEXPECTED_POSITIONAL_ARGUMENT"):
+            self.spark.sql("SELECT test_udf(a => id, id * 10) FROM range(2)").show()
+
     def test_named_arguments_and_defaults(self):
         @udf("int")
         def test_udf(a, b=0):
self.spark.sql("SELECT * FROM test_udtf(c => 'x')").show() + with self.assertRaisesRegex( + PythonException, r"eval\(\) got multiple values for argument 'a'" + ): + self.spark.sql("SELECT * FROM test_udtf(10, a => 100)").show() + def test_udtf_with_kwargs(self): @udtf(returnType="a: int, b: string") class TestUDTF: @@ -1867,6 +1872,16 @@ def eval(self, **kwargs): with self.subTest(query_no=i): assertDataFrameEqual(df, [Row(a=10, b="x")]) + # negative + with self.assertRaisesRegex( + AnalysisException, + "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE", + ): + self.spark.sql("SELECT * FROM test_udtf(a => 10, a => 100)").show() + + with self.assertRaisesRegex(AnalysisException, "UNEXPECTED_POSITIONAL_ARGUMENT"): + self.spark.sql("SELECT * FROM test_udtf(a => 10, 'x')").show() + def test_udtf_with_analyze_kwargs(self): @udtf class TestUDTF: diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 19c8c9c897b8e..d95a5c4672f86 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -452,13 +452,13 @@ def verify_element(result): def wrap_grouped_agg_pandas_udf(f, return_type): arrow_return_type = to_arrow_type(return_type) - def wrapped(*series): + def wrapped(*args, **kwargs): import pandas as pd - result = f(*series) + result = f(*args, **kwargs) return pd.Series([result]) - return lambda *a: (wrapped(*a), arrow_return_type) + return lambda *a, **kw: (wrapped(*a, **kw), arrow_return_type) def wrap_window_agg_pandas_udf(f, return_type, runner_conf, udf_index): @@ -484,19 +484,19 @@ def wrap_unbounded_window_agg_pandas_udf(f, return_type): # the scalar value. arrow_return_type = to_arrow_type(return_type) - def wrapped(*series): + def wrapped(*args, **kwargs): import pandas as pd - result = f(*series) - return pd.Series([result]).repeat(len(series[0])) + result = f(*args, **kwargs) + return pd.Series([result]).repeat(len((list(args) + list(kwargs.values()))[0])) - return lambda *a: (wrapped(*a), arrow_return_type) + return lambda *a, **kw: (wrapped(*a, **kw), arrow_return_type) def wrap_bounded_window_agg_pandas_udf(f, return_type): arrow_return_type = to_arrow_type(return_type) - def wrapped(begin_index, end_index, *series): + def wrapped(begin_index, end_index, *args, **kwargs): import pandas as pd result = [] @@ -521,11 +521,12 @@ def wrapped(begin_index, end_index, *series): # Note: Calling reset_index on the slices will increase the cost # of creating slices by about 100%. Therefore, for performance # reasons we don't do it here. - series_slices = [s.iloc[begin_array[i] : end_array[i]] for s in series] - result.append(f(*series_slices)) + args_slices = [s.iloc[begin_array[i] : end_array[i]] for s in args] + kwargs_slices = {k: s.iloc[begin_array[i] : end_array[i]] for k, s in kwargs.items()} + result.append(f(*args_slices, **kwargs_slices)) return pd.Series(result) - return lambda *a: (wrapped(*a), arrow_return_type) + return lambda *a, **kw: (wrapped(*a, **kw), arrow_return_type) def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index): @@ -535,6 +536,8 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index): PythonEvalType.SQL_BATCHED_UDF, PythonEvalType.SQL_ARROW_BATCHED_UDF, PythonEvalType.SQL_SCALAR_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, + PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF, # The below doesn't support named argument, but shares the same protocol. 
PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF, ): diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 9a6d9c8b735be..b93f87e77b97f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -3003,6 +3003,10 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor // we need to make sure that col1 to col5 are all projected from the child of the Window // operator. val extractedExprMap = mutable.LinkedHashMap.empty[Expression, NamedExpression] + def getOrExtract(key: Expression, value: Expression): Expression = { + extractedExprMap.getOrElseUpdate(key.canonicalized, + Alias(value, s"_w${extractedExprMap.size}")()).toAttribute + } def extractExpr(expr: Expression): Expression = expr match { case ne: NamedExpression => // If a named expression is not in regularExpressions, add it to @@ -3016,11 +3020,14 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor ne case e: Expression if e.foldable => e // No need to create an attribute reference if it will be evaluated as a Literal. + case e: NamedArgumentExpression => + // For NamedArgumentExpression, we extract the value and replace it with + // an AttributeReference (with an internal column name, e.g. "_w0"). + NamedArgumentExpression(e.key, getOrExtract(e, e.value)) case e: Expression => // For other expressions, we extract it and replace it with an AttributeReference (with // an internal column name, e.g. "_w0"). - extractedExprMap.getOrElseUpdate(e.canonicalized, - Alias(e, s"_w${extractedExprMap.size}")()).toAttribute + getOrExtract(e, e) } // Now, we extract regular expressions from expressionsWithWindowFunctions diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala index 73560a596ca58..7e349b665f352 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode} import org.apache.spark.sql.execution.aggregate.UpdatingSessionsIterator +import org.apache.spark.sql.execution.python.EvalPythonExec.ArgumentMetadata import org.apache.spark.sql.types.{DataType, StructField, StructType} import org.apache.spark.util.Utils @@ -109,14 +110,20 @@ case class AggregateInPandasExec( // Also eliminate duplicate UDF inputs. 
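The worker.py changes above are the crux on the Python side: each wrapper now forwards keyword arguments through to the user function, and the unbounded-window wrapper takes the group length from the first available input, positional or keyword, via `(list(args) + list(kwargs.values()))[0]`. A stripped-down sketch of the grouped-agg wrapper, omitting the Arrow return-type plumbing and not the actual worker code, might look like:

# Conceptual sketch, not the actual worker code.
import pandas as pd

def wrap_grouped_agg(f):
    def wrapped(*args, **kwargs):
        # Forward positional and keyword Series alike to the user function,
        # which returns one scalar per group...
        result = f(*args, **kwargs)
        # ...and wrap it in a length-1 Series for serialization.
        return pd.Series([result])
    return wrapped

wrapped = wrap_grouped_agg(lambda v, w: float((v * w).sum() / w.sum()))
print(wrapped(w=pd.Series([1.0, 2.0]), v=pd.Series([1.0, 2.0])))
# 0    1.666667
# dtype: float64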
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala
index 73560a596ca58..7e349b665f352 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala
@@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning}
 import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode}
 import org.apache.spark.sql.execution.aggregate.UpdatingSessionsIterator
+import org.apache.spark.sql.execution.python.EvalPythonExec.ArgumentMetadata
 import org.apache.spark.sql.types.{DataType, StructField, StructType}
 import org.apache.spark.util.Utils
 
@@ -109,14 +110,20 @@ case class AggregateInPandasExec(
     // Also eliminate duplicate UDF inputs.
     val allInputs = new ArrayBuffer[Expression]
     val dataTypes = new ArrayBuffer[DataType]
-    val argOffsets = inputs.map { input =>
+    val argMetas = inputs.map { input =>
       input.map { e =>
-        if (allInputs.exists(_.semanticEquals(e))) {
-          allInputs.indexWhere(_.semanticEquals(e))
+        val (key, value) = e match {
+          case NamedArgumentExpression(key, value) =>
+            (Some(key), value)
+          case _ =>
+            (None, e)
+        }
+        if (allInputs.exists(_.semanticEquals(value))) {
+          ArgumentMetadata(allInputs.indexWhere(_.semanticEquals(value)), key)
         } else {
-          allInputs += e
-          dataTypes += e.dataType
-          allInputs.length - 1
+          allInputs += value
+          dataTypes += value.dataType
+          ArgumentMetadata(allInputs.length - 1, key)
         }
       }.toArray
     }.toArray
@@ -164,10 +171,10 @@ case class AggregateInPandasExec(
         rows
       }
 
-      val columnarBatchIter = new ArrowPythonRunner(
+      val columnarBatchIter = new ArrowPythonWithNamedArgumentRunner(
         pyFuncs,
         PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
-        argOffsets,
+        argMetas,
         aggInputSchema,
         sessionLocalTimeZone,
         largeVarTypes,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
index f576637aa25b7..2fcc428407ecc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
@@ -52,7 +52,8 @@ case class UserDefinedPythonFunction(
   def builder(e: Seq[Expression]): Expression = {
     if (pythonEvalType == PythonEvalType.SQL_BATCHED_UDF
         || pythonEvalType == PythonEvalType.SQL_ARROW_BATCHED_UDF
-        || pythonEvalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF) {
+        || pythonEvalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF
+        || pythonEvalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF) {
       /*
        * Check if the named arguments:
        *   - don't have duplicated names
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasEvaluatorFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasEvaluatorFactory.scala
index a32d892622b4c..cf9f8c22ea082 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasEvaluatorFactory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasEvaluatorFactory.scala
@@ -25,11 +25,12 @@ import scala.collection.mutable.ArrayBuffer
 import org.apache.spark.{JobArtifactSet, PartitionEvaluator, PartitionEvaluatorFactory, SparkEnv, TaskContext}
 import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BoundReference, EmptyRow, Expression, JoinedRow, NamedExpression, PythonFuncExpression, PythonUDAF, SortOrder, SpecificInternalRow, UnsafeProjection, UnsafeRow, WindowExpression}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BoundReference, EmptyRow, Expression, JoinedRow, NamedArgumentExpression, NamedExpression, PythonFuncExpression, PythonUDAF, SortOrder, SpecificInternalRow, UnsafeProjection, UnsafeRow, WindowExpression}
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray
 import org.apache.spark.sql.execution.metric.SQLMetric
+import org.apache.spark.sql.execution.python.EvalPythonExec.ArgumentMetadata
 import org.apache.spark.sql.execution.window.{SlidingWindowFunctionFrame, UnboundedFollowingWindowFunctionFrame, UnboundedPrecedingWindowFunctionFrame, UnboundedWindowFunctionFrame, WindowEvaluatorFactoryBase, WindowFunctionFrame}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{DataType, IntegerType, StructField, StructType}
@@ -170,14 +171,20 @@ class WindowInPandasEvaluatorFactory(
     // handles UDF inputs.
     private val dataInputs = new ArrayBuffer[Expression]
     private val dataInputTypes = new ArrayBuffer[DataType]
-    private val argOffsets = inputs.map { input =>
+    private val argMetas = inputs.map { input =>
       input.map { e =>
-        if (dataInputs.exists(_.semanticEquals(e))) {
-          dataInputs.indexWhere(_.semanticEquals(e))
+        val (key, value) = e match {
+          case NamedArgumentExpression(key, value) =>
+            (Some(key), value)
+          case _ =>
+            (None, e)
+        }
+        if (dataInputs.exists(_.semanticEquals(value))) {
+          ArgumentMetadata(dataInputs.indexWhere(_.semanticEquals(value)), key)
         } else {
-          dataInputs += e
-          dataInputTypes += e.dataType
-          dataInputs.length - 1
+          dataInputs += value
+          dataInputTypes += value.dataType
+          ArgumentMetadata(dataInputs.length - 1, key)
         }
       }.toArray
     }.toArray
@@ -206,11 +213,15 @@ class WindowInPandasEvaluatorFactory(
     pyFuncs.indices.foreach { exprIndex =>
       val frameIndex = expressionIndexToFrameIndex(exprIndex)
       if (isBounded(frameIndex)) {
-        argOffsets(exprIndex) =
-          Array(lowerBoundIndex(frameIndex), upperBoundIndex(frameIndex)) ++
-            argOffsets(exprIndex).map(_ + windowBoundsInput.length)
+        argMetas(exprIndex) =
+          Array(
+            ArgumentMetadata(lowerBoundIndex(frameIndex), None),
+            ArgumentMetadata(upperBoundIndex(frameIndex), None)) ++
+            argMetas(exprIndex).map(
+              meta => ArgumentMetadata(meta.offset + windowBoundsInput.length, meta.name))
       } else {
-        argOffsets(exprIndex) = argOffsets(exprIndex).map(_ + windowBoundsInput.length)
+        argMetas(exprIndex) = argMetas(exprIndex).map(
+          meta => ArgumentMetadata(meta.offset + windowBoundsInput.length, meta.name))
       }
     }
@@ -346,10 +357,10 @@ class WindowInPandasEvaluatorFactory(
         }
       }
 
-      val windowFunctionResult = new ArrowPythonRunner(
+      val windowFunctionResult = new ArrowPythonWithNamedArgumentRunner(
         pyFuncs,
         PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF,
-        argOffsets,
+        argMetas,
         pythonInputSchema,
         sessionLocalTimeZone,
         largeVarTypes,
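On the JVM side, both exec nodes now build `ArgumentMetadata(offset, name)` pairs instead of bare offsets, so the deduplicated input projection is shared across UDFs while the argument name, when present, travels to the Python worker. A conceptual Python analogue of that dedup loop, not the actual Scala code, using plain equality as a stand-in for `semanticEquals` and strings for expressions:

# Conceptual sketch, not the actual Scala code.
from dataclasses import dataclass
from typing import Optional

@dataclass
class ArgumentMetadata:  # mirrors EvalPythonExec.ArgumentMetadata
    offset: int
    name: Optional[str]

def to_arg_metas(inputs):
    """inputs: per-UDF lists of (name_or_None, expr) pairs; a named argument
    contributes its key, a positional argument contributes None."""
    all_inputs = []  # deduplicated child projection, in order of first use
    arg_metas = []
    for udf_inputs in inputs:
        metas = []
        for name, expr in udf_inputs:
            if expr in all_inputs:  # stand-in for semanticEquals
                metas.append(ArgumentMetadata(all_inputs.index(expr), name))
            else:
                all_inputs.append(expr)
                metas.append(ArgumentMetadata(len(all_inputs) - 1, name))
        arg_metas.append(metas)
    return all_inputs, arg_metas

# Two UDFs sharing column "v": it is projected once, at offset 0.
cols, metas = to_arg_metas([[(None, "v"), ("w", "w")], [("v", "v")]])
print(cols)
# ['v', 'w']
print(metas)
# [[ArgumentMetadata(offset=0, name=None), ArgumentMetadata(offset=1, name='w')],
#  [ArgumentMetadata(offset=0, name='v')]]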