
[SPARK-44952][SQL][PYTHON] Support named arguments in aggregate Pandas UDFs #42663

Closed
ueshin wants to merge 4 commits


@ueshin ueshin commented Aug 24, 2023

What changes were proposed in this pull request?

Supports named arguments in aggregate Pandas UDFs.

For example:

>>> import pandas as pd
>>> from pyspark.sql.functions import pandas_udf
>>> @pandas_udf("double")
... def weighted_mean(v: pd.Series, w: pd.Series) -> float:
...     import numpy as np
...     return np.average(v, weights=w)
...
>>> df = spark.createDataFrame(
...     [(1, 1.0, 1.0), (1, 2.0, 2.0), (2, 3.0, 1.0), (2, 5.0, 2.0), (2, 10.0, 3.0)],
...     ("id", "v", "w"))

>>> df.groupby("id").agg(weighted_mean(v=df["v"], w=df["w"])).show()
+---+-----------------------------+
| id|weighted_mean(v => v, w => w)|
+---+-----------------------------+
|  1|           1.6666666666666667|
|  2|            7.166666666666667|
+---+-----------------------------+

>>> df.groupby("id").agg(weighted_mean(w=df["w"], v=df["v"])).show()
+---+-----------------------------+
| id|weighted_mean(w => w, v => v)|
+---+-----------------------------+
|  1|           1.6666666666666667|
|  2|            7.166666666666667|
+---+-----------------------------+

or with window:

>>> from pyspark.sql import Window
>>> w = Window.partitionBy("id").orderBy("v").rowsBetween(-2, 1)

>>> df.withColumn("wm", weighted_mean(v=df.v, w=df.w).over(w)).show()
+---+----+---+------------------+
| id|   v|  w|                wm|
+---+----+---+------------------+
|  1| 1.0|1.0|1.6666666666666667|
|  1| 2.0|2.0|1.6666666666666667|
|  2| 3.0|1.0| 4.333333333333333|
|  2| 5.0|2.0| 7.166666666666667|
|  2|10.0|3.0| 7.166666666666667|
+---+----+---+------------------+

>>> df.withColumn("wm", weighted_mean(w=df.w, v=df.v).over(w)).show()
+---+----+---+------------------+
| id|   v|  w|                wm|
+---+----+---+------------------+
|  1| 1.0|1.0|1.6666666666666667|
|  1| 2.0|2.0|1.6666666666666667|
|  2| 3.0|1.0| 4.333333333333333|
|  2| 5.0|2.0| 7.166666666666667|
|  2|10.0|3.0| 7.166666666666667|
+---+----+---+------------------+
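
The window results above can be cross-checked without Spark. The following is a minimal pure-Python sketch, not part of the PR: `rolling_weighted_mean` is a hypothetical helper that applies the same weighted mean over a `rowsBetween(-2, 1)` frame, assuming each partition is already sorted by `v`.

```python
def weighted_mean(values, weights):
    """Weighted average; mirrors np.average(v, weights=w) from the example."""
    return sum(v * w for v, w in zip(values, weights)) / sum(weights)


def rolling_weighted_mean(rows, start=-2, end=1):
    """Apply weighted_mean over the row frame [i + start, i + end] of each row.

    `rows` is a single partition: a list of (v, w) tuples sorted by v.
    The frame is clipped at the partition boundaries.
    """
    out = []
    for i in range(len(rows)):
        lo = max(0, i + start)
        hi = min(len(rows), i + end + 1)
        frame = rows[lo:hi]
        out.append(weighted_mean([v for v, _ in frame], [w for _, w in frame]))
    return out


# The two partitions (id = 1 and id = 2) from the example DataFrame.
partition_1 = [(1.0, 1.0), (2.0, 2.0)]
partition_2 = [(3.0, 1.0), (5.0, 2.0), (10.0, 3.0)]

print(rolling_weighted_mean(partition_1))
print(rolling_weighted_mean(partition_2))
```

For `id = 2`, the first row only sees itself and the next row, which is why its value (13/3) differs from the other two rows, which both see the whole partition (43/6).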

Why are the changes needed?

Now that named-argument support has been added (#41796, #42020), aggregate Pandas UDFs can support it as well.

Does this PR introduce any user-facing change?

Yes, named arguments will be available for aggregate Pandas UDFs.

How was this patch tested?

Added related tests.

Was this patch authored or co-authored using generative AI tooling?

No.


ueshin commented Aug 24, 2023

Comment on lines 3024 to 3030
case e: NamedArgumentExpression =>
  // For NamedArgumentExpression, we extract the value and replace it with
  // an AttributeReference (with an internal column name, e.g. "_w0").
  NamedArgumentExpression(
    e.key,
    extractedExprMap.getOrElseUpdate(e.canonicalized,
      Alias(e.value, s"_w${extractedExprMap.size}")()).toAttribute)

@dtenedor Does this change make sense when named arguments are passed to window functions?
It tries to keep the named arguments on the window function side; otherwise they would end up in the Project that is generated here, which cannot be processed because NamedArgumentExpression is marked as Unevaluable.
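
To illustrate the idea in the snippet under review, here is a toy Python model, not Spark's implementation. `NamedArg` and `extract` are hypothetical names, expressions are modeled as plain strings, and the map is keyed on the value expression (a simplification; the Scala code keys on `e.canonicalized`). The point is that the argument name stays on the function call while only the value is swapped for an internal column reference.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class NamedArg:
    """Toy stand-in for a named argument: a key plus a value expression."""
    key: str
    value: str  # the value expression, modeled here as a plain string


def extract(args, extracted):
    """Replace each argument's value with an internal column reference.

    `extracted` maps a value expression to its alias ("_w0", "_w1", ...),
    creating the alias on first use, similar in spirit to getOrElseUpdate.
    Named arguments keep their key; positional arguments become bare refs.
    """
    rewritten = []
    for arg in args:
        is_named = isinstance(arg, NamedArg)
        value = arg.value if is_named else arg
        if value not in extracted:
            extracted[value] = f"_w{len(extracted)}"
        ref = extracted[value]
        rewritten.append(NamedArg(arg.key, ref) if is_named else ref)
    return rewritten


extracted = {}
call = extract([NamedArg("w", "w + 1"), NamedArg("v", "v")], extracted)
print(call)       # the names "w" and "v" survive on the function call
print(extracted)  # only plain value expressions would go to the projection
```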

Contributor

Good question. NamedArgumentExpression is marked as Unevaluable because the analyzer is meant to match the (name, value) argument pairs provided in a function call against the expected ordered parameter list (including parameter names and types) of the function signature [1].

By the end of analysis, these expressions should be gone as a result of rearranging the provided function arguments to match the expected order, if necessary.

We might just want to deduplicate the logic from L3034-3035 into one place that we can reuse here, but the general idea looks right.

[1] https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/FunctionBuilderBase.scala#L74-L179
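
The rearrangement described in this comment can be sketched in a few lines of Python. This is only an illustration under simplifying assumptions (no default values, no type checks); `rearrange` is a hypothetical helper, and the error messages are made up rather than Spark's actual error classes.

```python
def rearrange(param_names, positional, named):
    """Return argument values in declared parameter order.

    param_names: declared parameter names, in order
    positional:  values passed positionally (fill parameters left to right)
    named:       list of (name, value) pairs, in call order
    """
    if len(positional) > len(param_names):
        raise ValueError("too many positional arguments")
    # Positional values claim the leading parameters.
    slots = dict(zip(param_names, positional))
    for name, value in named:
        if name not in param_names:
            raise ValueError(f"unrecognized parameter name: {name}")
        if name in slots:
            # Covers both a repeated name and a name already set positionally.
            raise ValueError(f"parameter assigned more than once: {name}")
        slots[name] = value
    missing = [p for p in param_names if p not in slots]
    if missing:
        raise ValueError(f"missing required parameters: {missing}")
    return [slots[p] for p in param_names]


# weighted_mean(w => w_col, v => v_col) is reordered to (v_col, w_col).
print(rearrange(["v", "w"], [], [("w", "w_col"), ("v", "v_col")]))
```

Once the values are back in declared order, the named wrappers are no longer needed, which is consistent with the expectation that they are gone by the end of analysis.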

@ueshin ueshin marked this pull request as ready for review August 24, 2023 23:28
@ueshin ueshin requested a review from dtenedor August 25, 2023 22:54

@dtenedor dtenedor left a comment


In general this implementation LGTM. I tried to think of more test cases and left a couple of ideas below; this looks good to go once we add them.

            with self.subTest(query_no=i):
                assertDataFrameEqual(aggregated, df.groupby("id").agg(mean(df.v).alias("wm")))

    def test_named_arguments_negative(self):

Can you please also port these negative tests to the 'kwargs' Pandas UDF case? It would be good to make sure we do the same checks there. Same for the window function tests.

                "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE",
            ):
                self.spark.sql(
                    "SELECT id, weighted_mean(v => v, v => w) as wm FROM v GROUP BY id"

Thanks for this test! In addition, could we add a negative test case where we provide a positional argument matching the first parameter of the function, followed by a named argument with the same name as that first parameter? It should return an error in that case. Same for the window function tests.
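
As a point of comparison only, plain Python rejects the same call shape. The sketch below uses the standard library's `inspect.signature(...).bind` on a placeholder function that merely shares the parameter names `v` and `w`; it says nothing about which error Spark raises, and `try_bind` is a hypothetical helper.

```python
import inspect


def weighted_mean(v, w):
    """Placeholder with the same parameter names as the UDF in this PR."""
    return None


sig = inspect.signature(weighted_mean)


def try_bind(*args, **kwargs):
    """Return None if the call shape is valid, else the TypeError message."""
    try:
        sig.bind(*args, **kwargs)
        return None
    except TypeError as exc:
        return str(exc)


# Valid: first parameter given positionally, second by name.
print(try_bind("v_col", w="w_col"))

# Invalid: "v" is supplied positionally and then again by name.
print(try_bind("v_col", v="other_col", w="w_col"))
```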


ueshin commented Sep 1, 2023

Thanks! merging to master.

@ueshin ueshin closed this in df534c3 Sep 1, 2023