[SPARK-44749][SQL][PYTHON] Support named arguments in Python UDTF #42422

Closed · wants to merge 11 commits
@@ -47,6 +47,7 @@ message Expression {
UnresolvedNamedLambdaVariable unresolved_named_lambda_variable = 14;
CommonInlineUserDefinedFunction common_inline_user_defined_function = 15;
CallFunction call_function = 16;
NamedArgumentExpression named_argument_expression = 17;

// This field is used to mark extensions to the protocol. When plugins generate arbitrary
// relations they can add them here. During the planning the correct resolution is done.
@@ -380,3 +381,11 @@ message CallFunction {
// (Optional) Function arguments. Empty arguments are allowed.
repeated Expression arguments = 2;
}

message NamedArgumentExpression {
  // (Required) The key of the named argument.
  string key = 1;

  // (Required) The value expression of the named argument.
  Expression value = 2;
}
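For illustration, here is a minimal sketch (not part of the diff) of what this message looks like when built through the generated Python bindings; the key and the literal payload are hypothetical example values:

```python
from pyspark.sql.connect import proto

# Encode the SQL named argument `x => 'abc'` as an Expression proto.
expr = proto.Expression()
expr.named_argument_expression.key = "x"
expr.named_argument_expression.value.literal.string = "abc"
```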
@@ -1383,6 +1383,8 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
        transformCommonInlineUserDefinedFunction(exp.getCommonInlineUserDefinedFunction)
      case proto.Expression.ExprTypeCase.CALL_FUNCTION =>
        transformCallFunction(exp.getCallFunction)
      case proto.Expression.ExprTypeCase.NAMED_ARGUMENT_EXPRESSION =>
        transformNamedArgumentExpression(exp.getNamedArgumentExpression)
      case _ =>
        throw InvalidPlanInput(
          s"Expression with ID: ${exp.getExprTypeCase.getNumber} is not supported")
@@ -1504,6 +1506,11 @@
false)
}

  private def transformNamedArgumentExpression(
      namedArg: proto.NamedArgumentExpression): Expression = {
    NamedArgumentExpression(namedArg.getKey, transformExpression(namedArg.getValue))
  }

private def unpackUdf(fun: proto.CommonInlineUserDefinedFunction): UdfPacket = {
unpackScalarScalaUDF[UdfPacket](fun.getScalarScalaUdf)
}
14 changes: 14 additions & 0 deletions python/pyspark/sql/column.py
@@ -73,9 +73,23 @@ def _to_java_expr(col: "ColumnOrName") -> JavaObject:
    return _to_java_column(col).expr()


@overload
def _to_seq(sc: SparkContext, cols: Iterable[JavaObject]) -> JavaObject:
    pass


@overload
def _to_seq(
    sc: SparkContext,
    cols: Iterable["ColumnOrName"],
    converter: Optional[Callable[["ColumnOrName"], JavaObject]],
) -> JavaObject:
    pass


def _to_seq(
    sc: SparkContext,
    cols: Union[Iterable["ColumnOrName"], Iterable[JavaObject]],
    converter: Optional[Callable[["ColumnOrName"], JavaObject]] = None,
) -> JavaObject:
    """
20 changes: 20 additions & 0 deletions python/pyspark/sql/connect/expressions.py
@@ -1054,3 +1054,23 @@ def __repr__(self) -> str:
return f"CallFunction('{self._name}', {', '.join([str(arg) for arg in self._args])})"
else:
return f"CallFunction('{self._name}')"


class NamedArgumentExpression(Expression):
    def __init__(self, key: str, value: Expression):
        super().__init__()

        assert isinstance(key, str)
        self._key = key

        assert isinstance(value, Expression)
        self._value = value

    def to_plan(self, session: "SparkConnectClient") -> "proto.Expression":
        expr = proto.Expression()
        expr.named_argument_expression.key = self._key
        expr.named_argument_expression.value.CopyFrom(self._value.to_plan(session))
        return expr

    def __repr__(self) -> str:
        return f"{self._key} => {self._value}"
122 changes: 62 additions & 60 deletions python/pyspark/sql/connect/proto/expressions_pb2.py

Large diffs are not rendered by default.

34 changes: 34 additions & 0 deletions python/pyspark/sql/connect/proto/expressions_pb2.pyi
@@ -1102,6 +1102,7 @@ class Expression(google.protobuf.message.Message):
UNRESOLVED_NAMED_LAMBDA_VARIABLE_FIELD_NUMBER: builtins.int
COMMON_INLINE_USER_DEFINED_FUNCTION_FIELD_NUMBER: builtins.int
CALL_FUNCTION_FIELD_NUMBER: builtins.int
NAMED_ARGUMENT_EXPRESSION_FIELD_NUMBER: builtins.int
EXTENSION_FIELD_NUMBER: builtins.int
@property
def literal(self) -> global___Expression.Literal: ...
@@ -1138,6 +1139,8 @@ class Expression(google.protobuf.message.Message):
@property
def call_function(self) -> global___CallFunction: ...
@property
def named_argument_expression(self) -> global___NamedArgumentExpression: ...
@property
def extension(self) -> google.protobuf.any_pb2.Any:
"""This field is used to mark extensions to the protocol. When plugins generate arbitrary
relations they can add them here. During the planning the correct resolution is done.
@@ -1162,6 +1165,7 @@
| None = ...,
common_inline_user_defined_function: global___CommonInlineUserDefinedFunction | None = ...,
call_function: global___CallFunction | None = ...,
named_argument_expression: global___NamedArgumentExpression | None = ...,
extension: google.protobuf.any_pb2.Any | None = ...,
) -> None: ...
def HasField(
@@ -1185,6 +1189,8 @@
b"lambda_function",
"literal",
b"literal",
"named_argument_expression",
b"named_argument_expression",
"sort_order",
b"sort_order",
"unresolved_attribute",
@@ -1226,6 +1232,8 @@
b"lambda_function",
"literal",
b"literal",
"named_argument_expression",
b"named_argument_expression",
"sort_order",
b"sort_order",
"unresolved_attribute",
@@ -1265,6 +1273,7 @@
"unresolved_named_lambda_variable",
"common_inline_user_defined_function",
"call_function",
"named_argument_expression",
"extension",
] | None: ...

@@ -1505,3 +1514,28 @@
) -> None: ...

global___CallFunction = CallFunction

class NamedArgumentExpression(google.protobuf.message.Message):
    DESCRIPTOR: google.protobuf.descriptor.Descriptor

    KEY_FIELD_NUMBER: builtins.int
    VALUE_FIELD_NUMBER: builtins.int
    key: builtins.str
    """(Required) The key of the named argument."""
    @property
    def value(self) -> global___Expression:
        """(Required) The value expression of the named argument."""
    def __init__(
        self,
        *,
        key: builtins.str = ...,
        value: global___Expression | None = ...,
    ) -> None: ...
    def HasField(
        self, field_name: typing_extensions.Literal["value", b"value"]
    ) -> builtins.bool: ...
    def ClearField(
        self, field_name: typing_extensions.Literal["key", b"key", "value", b"value"]
    ) -> None: ...

global___NamedArgumentExpression = NamedArgumentExpression
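As a hedged sketch of consuming the new oneof case from the stub above (this assumes the `Expression` oneof is named `expr_type`, as suggested by the `WhichOneof` signature in the full stub):

```python
from pyspark.sql.connect import proto

def describe(expr: proto.Expression) -> str:
    # Dispatch on the expression variant; "named_argument_expression"
    # is one of the oneof case names listed in the stub.
    if expr.WhichOneof("expr_type") == "named_argument_expression":
        named = expr.named_argument_expression
        return f"named argument {named.key!r}"
    return "other expression"
```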
18 changes: 10 additions & 8 deletions python/pyspark/sql/connect/udtf.py
@@ -22,11 +22,11 @@
 check_dependencies(__name__)
 
 import warnings
-from typing import Type, TYPE_CHECKING, Optional, Union
+from typing import List, Type, TYPE_CHECKING, Optional, Union
 
 from pyspark.rdd import PythonEvalType
 from pyspark.sql.connect.column import Column
-from pyspark.sql.connect.expressions import ColumnReference
+from pyspark.sql.connect.expressions import ColumnReference, Expression, NamedArgumentExpression
 from pyspark.sql.connect.plan import (
     CommonInlineUserDefinedTableFunction,
     PythonUDTF,
@@ -146,12 +146,14 @@ def __init__(
         self.deterministic = deterministic
 
     def _build_common_inline_user_defined_table_function(
-        self, *cols: "ColumnOrName"
+        self, *args: "ColumnOrName", **kwargs: "ColumnOrName"
     ) -> CommonInlineUserDefinedTableFunction:
-        arg_cols = [
-            col if isinstance(col, Column) else Column(ColumnReference(col)) for col in cols
+        def to_expr(col: "ColumnOrName") -> Expression:
+            return col._expr if isinstance(col, Column) else ColumnReference(col)
+
+        arg_exprs: List[Expression] = [to_expr(arg) for arg in args] + [
+            NamedArgumentExpression(key, to_expr(value)) for key, value in kwargs.items()
         ]
-        arg_exprs = [col._expr for col in arg_cols]
 
         udtf = PythonUDTF(
             func=self.func,
@@ -166,13 +168,13 @@ def _build_common_inline_user_defined_table_function(
             arguments=arg_exprs,
         )
 
-    def __call__(self, *cols: "ColumnOrName") -> "DataFrame":
+    def __call__(self, *args: "ColumnOrName", **kwargs: "ColumnOrName") -> "DataFrame":
         from pyspark.sql.connect.session import SparkSession
         from pyspark.sql.connect.dataframe import DataFrame
 
         session = SparkSession.active()
 
-        plan = self._build_common_inline_user_defined_table_function(*cols)
+        plan = self._build_common_inline_user_defined_table_function(*args, **kwargs)
         return DataFrame.withPlan(plan, session)
 
     def asNondeterministic(self) -> "UserDefinedTableFunction":
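To make the new signatures concrete, here is a hedged usage sketch (`MyUDTF` is a hypothetical example class; keyword arguments are turned into `NamedArgumentExpression`s by the code above):

```python
from pyspark.sql.functions import lit, udtf

@udtf(returnType="a: int, b: string")
class MyUDTF:
    def eval(self, a, b):
        yield a, b

# Positional and keyword arguments can now be mixed when calling the UDTF;
# keyword arguments may appear in any order.
MyUDTF(lit(1), b=lit("x")).show()
```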
38 changes: 38 additions & 0 deletions python/pyspark/sql/functions.py
@@ -15547,6 +15547,12 @@ def udtf(

    .. versionadded:: 3.5.0

    .. versionchanged:: 4.0.0
        Supports Python-side analysis.

    .. versionchanged:: 4.0.0
        Supports keyword arguments.

    Parameters
    ----------
    cls : class
@@ -15623,6 +15629,38 @@
    |  1|  x|
    +---+---+

    A UDTF can use keyword arguments:

    >>> @udtf
    ... class TestUDTFWithKwargs:
    ...     @staticmethod
    ...     def analyze(
    ...         a: AnalyzeArgument, b: AnalyzeArgument, **kwargs: AnalyzeArgument
    ...     ) -> AnalyzeResult:
    ...         return AnalyzeResult(
    ...             StructType().add("a", a.data_type)
    ...             .add("b", b.data_type)
    ...             .add("x", kwargs["x"].data_type)
    ...         )
    ...
    ...     def eval(self, a, b, **kwargs):
    ...         yield a, b, kwargs["x"]
    ...
    >>> TestUDTFWithKwargs(lit(1), x=lit("x"), b=lit("b")).show()
    +---+---+---+
    |  a|  b|  x|
    +---+---+---+
    |  1|  b|  x|
    +---+---+---+

    >>> _ = spark.udtf.register("test_udtf", TestUDTFWithKwargs)
    >>> spark.sql("SELECT * FROM test_udtf(1, x=>'x', b=>'b')").show()
    +---+---+---+
    |  a|  b|  x|
    +---+---+---+
    |  1|  b|  x|
    +---+---+---+

    Arrow optimization can be explicitly enabled when creating UDTFs:

    >>> @udtf(returnType="c1: int, c2: int", useArrow=True)
88 changes: 88 additions & 0 deletions python/pyspark/sql/tests/test_udtf.py
@@ -1565,6 +1565,7 @@ def eval(self, **kwargs):
        expected = [Row(c1="hello", c2="world")]
        assertDataFrameEqual(TestUDTF(), expected)
        assertDataFrameEqual(self.spark.sql("SELECT * FROM test_udtf()"), expected)
        assertDataFrameEqual(self.spark.sql("SELECT * FROM test_udtf(a=>1)"), expected)

        with self.assertRaisesRegex(
            AnalysisException, r"analyze\(\) takes 0 positional arguments but 1 was given"
@@ -1795,6 +1796,93 @@ def terminate(self):
        assertSchemaEqual(df.schema, StructType().add("col1", IntegerType()))
        assertDataFrameEqual(df, [Row(col1=10), Row(col1=100)])

    def test_udtf_with_named_arguments(self):
        @udtf(returnType="a: int")
        class TestUDTF:
            def eval(self, a, b):
                yield a,

        self.spark.udtf.register("test_udtf", TestUDTF)

        for i, df in enumerate(
            [
                self.spark.sql("SELECT * FROM test_udtf(a=>10, b=>'x')"),
                self.spark.sql("SELECT * FROM test_udtf(b=>'x', a=>10)"),
                TestUDTF(a=lit(10), b=lit("x")),
                TestUDTF(b=lit("x"), a=lit(10)),
            ]
        ):
            with self.subTest(query_no=i):
                assertDataFrameEqual(df, [Row(a=10)])

    def test_udtf_with_named_arguments_negative(self):
        @udtf(returnType="a: int")
        class TestUDTF:
            def eval(self, a, b):
                yield a,

        self.spark.udtf.register("test_udtf", TestUDTF)

        with self.assertRaisesRegex(
            AnalysisException,
            "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE",
        ):
            self.spark.sql("SELECT * FROM test_udtf(a=>10, a=>100)").show()

        with self.assertRaisesRegex(AnalysisException, "UNEXPECTED_POSITIONAL_ARGUMENT"):
            self.spark.sql("SELECT * FROM test_udtf(a=>10, 'x')").show()

        with self.assertRaisesRegex(
            PythonException, r"eval\(\) got an unexpected keyword argument 'c'"
        ):
            self.spark.sql("SELECT * FROM test_udtf(c=>'x')").show()

    def test_udtf_with_kwargs(self):
        @udtf(returnType="a: int, b: string")
        class TestUDTF:
            def eval(self, **kwargs):
                yield kwargs["a"], kwargs["b"]

        self.spark.udtf.register("test_udtf", TestUDTF)

        for i, df in enumerate(
            [
                self.spark.sql("SELECT * FROM test_udtf(a=>10, b=>'x')"),
                self.spark.sql("SELECT * FROM test_udtf(b=>'x', a=>10)"),
                TestUDTF(a=lit(10), b=lit("x")),
                TestUDTF(b=lit("x"), a=lit(10)),
            ]
        ):
            with self.subTest(query_no=i):
                assertDataFrameEqual(df, [Row(a=10, b="x")])

Review discussion on this test:

Contributor: What would be the error message if a named argument is used incorrectly? For example:

1. duplicated input argument names: a => 10, a => 10
2. non-existing argument name: c => 10
3. incorrect combination of positional and named arguments: test_udtf(a => 10, 'x')

I am afraid that if we directly leverage Python's kwargs, the error messages won't be as user-friendly as the SQL function ones.

ueshin (Member, Author, Aug 10, 2023): That's a good point. So far we just rely on Python's error. @dtenedor, what are the error messages like when named arguments are applied to other functions in the above cases? Are there any examples we can follow here?

dtenedor (Contributor): Yeah, I believe @learningchess2003 added these checks in [1]; they currently live in FunctionBuilderBase.scala [2]. If we reuse those checks, the error messages for Python UDTFs would be consistent with those for other Spark functions.

[1] #42020
[2] https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/FunctionBuilderBase.scala#L107-L128

ueshin (Member, Author): Updated to raise the following errors:

1. Duplicated input argument names (a => 10, a => 10) are checked in the analysis phase; an error with the error class DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE is raised.
2. A non-existing argument name (c => 10) is handled in the Python runtime, which raises:
   PySparkRuntimeError: [UDTF_EXEC_ERROR] User defined table function encountered an error in the 'eval' method: eval() got an unexpected keyword argument 'c'
3. An incorrect combination of positional and named arguments (test_udtf(a => 10, 'x')) is checked in the analysis phase; an error with the error class UNEXPECTED_POSITIONAL_ARGUMENT is raised.

Contributor: LGTM!
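To make that discussion concrete, here is a minimal Python sketch of the analysis-phase checks described above (the real implementation lives in Catalyst's FunctionBuilderBase.scala; the function name and argument shape here are illustrative only):

```python
from typing import List, Optional, Set, Tuple

def validate_arguments(args: List[Tuple[Optional[str], object]]) -> None:
    """Hypothetical stand-in: args pairs an optional name with each value;
    positional arguments carry name=None."""
    seen: Set[str] = set()
    named_seen = False
    for name, _ in args:
        if name is None:
            if named_seen:
                # e.g. test_udtf(a => 10, 'x')
                raise ValueError("UNEXPECTED_POSITIONAL_ARGUMENT")
        else:
            named_seen = True
            if name in seen:
                # e.g. test_udtf(a => 10, a => 100)
                raise ValueError(
                    "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT."
                    "DOUBLE_NAMED_ARGUMENT_REFERENCE"
                )
            seen.add(name)
```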

    def test_udtf_with_analyze_kwargs(self):
        @udtf
        class TestUDTF:
            @staticmethod
            def analyze(**kwargs: AnalyzeArgument) -> AnalyzeResult:
                return AnalyzeResult(
                    StructType(
                        [StructField(key, arg.data_type) for key, arg in sorted(kwargs.items())]
                    )
                )

            def eval(self, **kwargs):
                yield tuple(value for _, value in sorted(kwargs.items()))

        self.spark.udtf.register("test_udtf", TestUDTF)

        for i, df in enumerate(
            [
                self.spark.sql("SELECT * FROM test_udtf(a=>10, b=>'x')"),
                self.spark.sql("SELECT * FROM test_udtf(b=>'x', a=>10)"),
                TestUDTF(a=lit(10), b=lit("x")),
                TestUDTF(b=lit("x"), a=lit(10)),
            ]
        ):
            with self.subTest(query_no=i):
                assertDataFrameEqual(df, [Row(a=10, b="x")])


class UDTFTests(BaseUDTFTestsMixin, ReusedSQLTestCase):
    @classmethod