[SPARK-44952][SQL][PYTHON] Support named arguments in aggregate Pandas UDFs

### What changes were proposed in this pull request?

Supports named arguments in aggregate Pandas UDFs.

For example:

```py
>>> pandas_udf("double")
... def weighted_mean(v: pd.Series, w: pd.Series) -> float:
...     import numpy as np
...     return np.average(v, weights=w)
...
>>> df = spark.createDataFrame(
...     [(1, 1.0, 1.0), (1, 2.0, 2.0), (2, 3.0, 1.0), (2, 5.0, 2.0), (2, 10.0, 3.0)],
...     ("id", "v", "w"))

>>> df.groupby("id").agg(weighted_mean(v=df["v"], w=df["w"])).show()
+---+-----------------------------+
| id|weighted_mean(v => v, w => w)|
+---+-----------------------------+
|  1|           1.6666666666666667|
|  2|            7.166666666666667|
+---+-----------------------------+

>>> df.groupby("id").agg(weighted_mean(w=df["w"], v=df["v"])).show()
+---+-----------------------------+
| id|weighted_mean(w => w, v => v)|
+---+-----------------------------+
|  1|           1.6666666666666667|
|  2|            7.166666666666667|
+---+-----------------------------+
```

or with a window:

```py
>>> w = Window.partitionBy("id").orderBy("v").rowsBetween(-2, 1)

>>> df.withColumn("wm", weighted_mean(v=df.v, w=df.w).over(w)).show()
+---+----+---+------------------+
| id|   v|  w|                wm|
+---+----+---+------------------+
|  1| 1.0|1.0|1.6666666666666667|
|  1| 2.0|2.0|1.6666666666666667|
|  2| 3.0|1.0| 4.333333333333333|
|  2| 5.0|2.0| 7.166666666666667|
|  2|10.0|3.0| 7.166666666666667|
+---+----+---+------------------+

>>> df.withColumn("wm", weighted_mean_udf(w=df.w, v=df.v).over(w)).show()
+---+----+---+------------------+
| id|   v|  w|                wm|
+---+----+---+------------------+
|  1| 1.0|1.0|1.6666666666666667|
|  1| 2.0|2.0|1.6666666666666667|
|  2| 3.0|1.0| 4.333333333333333|
|  2| 5.0|2.0| 7.166666666666667|
|  2|10.0|3.0| 7.166666666666667|
+---+----+---+------------------+
```

### Why are the changes needed?

Named arguments support was added in apache#41796 and apache#42020, so aggregate Pandas UDFs can now support them as well.

### Does this PR introduce _any_ user-facing change?

Yes, named arguments will be available for aggregate Pandas UDFs.
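For instance, named arguments also work through the SQL `=>` syntax once the UDF is registered, as exercised by the new tests. A minimal sketch (assuming an active `spark` session and the `df` from the example above; the temp view name `t` is arbitrary):

```py
>>> import pandas as pd
>>> from pyspark.sql.functions import pandas_udf

>>> @pandas_udf("double")
... def weighted_mean(v: pd.Series, w: pd.Series) -> float:
...     import numpy as np
...     return np.average(v, weights=w)
...

>>> # Register the aggregate Pandas UDF so it can be called from SQL.
>>> spark.udf.register("weighted_mean", weighted_mean)
>>> df.createOrReplaceTempView("t")

>>> # Keyword arguments can be passed in any order with the => syntax.
>>> spark.sql("SELECT id, weighted_mean(w => w, v => v) AS wm FROM t GROUP BY id").show()
```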

### How was this patch tested?

Added related tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#42663 from ueshin/issues/SPARK-44952/kwargs.

Authored-by: Takuya UESHIN <ueshin@databricks.com>
Signed-off-by: Takuya UESHIN <ueshin@databricks.com>
ueshin committed Sep 1, 2023
1 parent e86849a commit df534c3
Showing 10 changed files with 429 additions and 40 deletions.
20 changes: 19 additions & 1 deletion python/pyspark/sql/pandas/functions.py
@@ -57,7 +57,7 @@ def pandas_udf(f=None, returnType=None, functionType=None):
Supports Spark Connect.
.. versionchanged:: 4.0.0
- Supports keyword-arguments in SCALAR type.
+ Supports keyword-arguments in SCALAR and GROUPED_AGG type.
Parameters
----------
@@ -267,6 +267,24 @@ def calculate(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
| 2| 6.0|
+---+-----------+
This type of Pandas UDF can use keyword arguments:
>>> @pandas_udf("double")
... def weighted_mean_udf(v: pd.Series, w: pd.Series) -> float:
... import numpy as np
... return np.average(v, weights=w)
...
>>> df = spark.createDataFrame(
... [(1, 1.0, 1.0), (1, 2.0, 2.0), (2, 3.0, 1.0), (2, 5.0, 2.0), (2, 10.0, 3.0)],
... ("id", "v", "w"))
>>> df.groupby("id").agg(weighted_mean_udf(w=df["w"], v=df["v"])).show()
+---+---------------------------------+
| id|weighted_mean_udf(w => w, v => v)|
+---+---------------------------------+
| 1| 1.6666666666666667|
| 2| 7.166666666666667|
+---+---------------------------------+
This UDF can also be used as a window function as below:
>>> from pyspark.sql import Window
147 changes: 145 additions & 2 deletions python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py
@@ -32,15 +32,15 @@
PandasUDFType,
)
from pyspark.sql.types import ArrayType, YearMonthIntervalType
-from pyspark.errors import AnalysisException, PySparkNotImplementedError
+from pyspark.errors import AnalysisException, PySparkNotImplementedError, PythonException
from pyspark.testing.sqlutils import (
ReusedSQLTestCase,
have_pandas,
have_pyarrow,
pandas_requirement_message,
pyarrow_requirement_message,
)
-from pyspark.testing.utils import QuietTest
+from pyspark.testing.utils import QuietTest, assertDataFrameEqual


if have_pandas:
@@ -575,6 +575,149 @@ def mean(x):

assert filtered.collect()[0]["mean"] == 42.0

def test_named_arguments(self):
df = self.data
weighted_mean = self.pandas_agg_weighted_mean_udf

with self.tempView("v"):
df.createOrReplaceTempView("v")
self.spark.udf.register("weighted_mean", weighted_mean)

for i, aggregated in enumerate(
[
df.groupby("id").agg(weighted_mean(df.v, w=df.w).alias("wm")),
df.groupby("id").agg(weighted_mean(v=df.v, w=df.w).alias("wm")),
df.groupby("id").agg(weighted_mean(w=df.w, v=df.v).alias("wm")),
self.spark.sql("SELECT id, weighted_mean(v, w => w) as wm FROM v GROUP BY id"),
self.spark.sql(
"SELECT id, weighted_mean(v => v, w => w) as wm FROM v GROUP BY id"
),
self.spark.sql(
"SELECT id, weighted_mean(w => w, v => v) as wm FROM v GROUP BY id"
),
]
):
with self.subTest(query_no=i):
assertDataFrameEqual(aggregated, df.groupby("id").agg(mean(df.v).alias("wm")))

def test_named_arguments_negative(self):
df = self.data
weighted_mean = self.pandas_agg_weighted_mean_udf

with self.tempView("v"):
df.createOrReplaceTempView("v")
self.spark.udf.register("weighted_mean", weighted_mean)

with self.assertRaisesRegex(
AnalysisException,
"DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE",
):
self.spark.sql(
"SELECT id, weighted_mean(v => v, v => w) as wm FROM v GROUP BY id"
).show()

with self.assertRaisesRegex(AnalysisException, "UNEXPECTED_POSITIONAL_ARGUMENT"):
self.spark.sql(
"SELECT id, weighted_mean(v => v, w) as wm FROM v GROUP BY id"
).show()

with self.assertRaisesRegex(
PythonException, r"weighted_mean\(\) got an unexpected keyword argument 'x'"
):
self.spark.sql(
"SELECT id, weighted_mean(v => v, x => w) as wm FROM v GROUP BY id"
).show()

with self.assertRaisesRegex(
PythonException, r"weighted_mean\(\) got multiple values for argument 'v'"
):
self.spark.sql(
"SELECT id, weighted_mean(v, v => w) as wm FROM v GROUP BY id"
).show()

def test_kwargs(self):
df = self.data

@pandas_udf("double", PandasUDFType.GROUPED_AGG)
def weighted_mean(**kwargs):
import numpy as np

return np.average(kwargs["v"], weights=kwargs["w"])

with self.tempView("v"):
df.createOrReplaceTempView("v")
self.spark.udf.register("weighted_mean", weighted_mean)

for i, aggregated in enumerate(
[
df.groupby("id").agg(weighted_mean(v=df.v, w=df.w).alias("wm")),
df.groupby("id").agg(weighted_mean(w=df.w, v=df.v).alias("wm")),
self.spark.sql(
"SELECT id, weighted_mean(v => v, w => w) as wm FROM v GROUP BY id"
),
self.spark.sql(
"SELECT id, weighted_mean(w => w, v => v) as wm FROM v GROUP BY id"
),
]
):
with self.subTest(query_no=i):
assertDataFrameEqual(aggregated, df.groupby("id").agg(mean(df.v).alias("wm")))

# negative
with self.assertRaisesRegex(
AnalysisException,
"DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE",
):
self.spark.sql(
"SELECT id, weighted_mean(v => v, v => w) as wm FROM v GROUP BY id"
).show()

with self.assertRaisesRegex(AnalysisException, "UNEXPECTED_POSITIONAL_ARGUMENT"):
self.spark.sql(
"SELECT id, weighted_mean(v => v, w) as wm FROM v GROUP BY id"
).show()

def test_named_arguments_and_defaults(self):
df = self.data

@pandas_udf("double", PandasUDFType.GROUPED_AGG)
def biased_sum(v, w=None):
return v.sum() + (w.sum() if w is not None else 100)

with self.tempView("v"):
df.createOrReplaceTempView("v")
self.spark.udf.register("biased_sum", biased_sum)

# without "w"
for i, aggregated in enumerate(
[
df.groupby("id").agg(biased_sum(df.v).alias("s")),
df.groupby("id").agg(biased_sum(v=df.v).alias("s")),
self.spark.sql("SELECT id, biased_sum(v) as s FROM v GROUP BY id"),
self.spark.sql("SELECT id, biased_sum(v => v) as s FROM v GROUP BY id"),
]
):
with self.subTest(with_w=False, query_no=i):
assertDataFrameEqual(
aggregated, df.groupby("id").agg((sum(df.v) + lit(100)).alias("s"))
)

# with "w"
for i, aggregated in enumerate(
[
df.groupby("id").agg(biased_sum(df.v, w=df.w).alias("s")),
df.groupby("id").agg(biased_sum(v=df.v, w=df.w).alias("s")),
df.groupby("id").agg(biased_sum(w=df.w, v=df.v).alias("s")),
self.spark.sql("SELECT id, biased_sum(v, w => w) as s FROM v GROUP BY id"),
self.spark.sql("SELECT id, biased_sum(v => v, w => w) as s FROM v GROUP BY id"),
self.spark.sql("SELECT id, biased_sum(w => w, v => v) as s FROM v GROUP BY id"),
]
):
with self.subTest(with_w=True, query_no=i):
assertDataFrameEqual(
aggregated, df.groupby("id").agg((sum(df.v) + sum(df.w)).alias("s"))
)


class GroupedAggPandasUDFTests(GroupedAggPandasUDFTestsMixin, ReusedSQLTestCase):
pass