[SPARK-44749][SQL][PYTHON] Support named arguments in Python UDTF
### What changes were proposed in this pull request?

Supports named arguments in Python UDTF.

For example:

```py
>>> udtf(returnType="a: int")
... class TestUDTF:
...     def eval(self, a, b):
...         yield a,
...
>>> spark.udtf.register("test_udtf", TestUDTF)

>>> TestUDTF(a=lit(10), b=lit("x")).show()
+---+
|  a|
+---+
| 10|
+---+

>>> TestUDTF(b=lit("x"), a=lit(10)).show()
+---+
|  a|
+---+
| 10|
+---+

>>> spark.sql("SELECT * FROM test_udtf(a=>10, b=>'x')").show()
+---+
|  a|
+---+
| 10|
+---+

>>> spark.sql("SELECT * FROM test_udtf(b=>'x', a=>10)").show()
+---+
|  a|
+---+
| 10|
+---+
```

or:

```py
>>> @udtf
... class TestUDTF:
...     @staticmethod
...     def analyze(**kwargs: AnalyzeArgument) -> AnalyzeResult:
...         return AnalyzeResult(
...             StructType(
...                 [StructField(key, arg.data_type) for key, arg in sorted(kwargs.items())]
...             )
...         )
...     def eval(self, **kwargs):
...         yield tuple(value for _, value in sorted(kwargs.items()))
...
>>> spark.udtf.register("test_udtf", TestUDTF)

>>> spark.sql("SELECT * FROM test_udtf(a=>10, b=>'x', x=>100.0)").show()
+---+---+-----+
|  a|  b|    x|
+---+---+-----+
| 10|  x|100.0|
+---+---+-----+

>>> spark.sql("SELECT * FROM test_udtf(x=>10, a=>'x', z=>100.0)").show()
+---+---+-----+
|  a|  x|    z|
+---+---+-----+
|  x| 10|100.0|
+---+---+-----+
```

### Why are the changes needed?

Named arguments are now supported in SQL function calls (#41796, #42020), so they should also be supported in Python UDTF.

### Does this PR introduce _any_ user-facing change?

Yes, named arguments will be available for Python UDTFs.

### How was this patch tested?

Added related tests.

Closes #42422 from ueshin/issues/SPARK-44749/kwargs.

Authored-by: Takuya UESHIN <ueshin@databricks.com>
Signed-off-by: Takuya UESHIN <ueshin@databricks.com>
ueshin committed Aug 14, 2023
1 parent 0fc85a2 commit d462956
Showing 18 changed files with 472 additions and 148 deletions.
@@ -47,6 +47,7 @@ message Expression {
UnresolvedNamedLambdaVariable unresolved_named_lambda_variable = 14;
CommonInlineUserDefinedFunction common_inline_user_defined_function = 15;
CallFunction call_function = 16;
NamedArgumentExpression named_argument_expression = 17;

// This field is used to mark extensions to the protocol. When plugins generate arbitrary
// relations they can add them here. During the planning the correct resolution is done.
@@ -380,3 +381,11 @@
// (Optional) Function arguments. Empty arguments are allowed.
repeated Expression arguments = 2;
}

message NamedArgumentExpression {
// (Required) The key of the named argument.
string key = 1;

// (Required) The value expression of the named argument.
Expression value = 2;
}
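For illustration, a SQL argument written as `a => 10` would be encoded with this new message roughly as follows — a sketch using the regenerated Python bindings under `pyspark.sql.connect.proto`; the key and literal value are illustrative:

```py
# Sketch: encoding the named argument `a => 10` with the new message
# (regenerated Python bindings assumed; values are illustrative).
from pyspark.sql.connect import proto

expr = proto.Expression()
expr.named_argument_expression.key = "a"
expr.named_argument_expression.value.literal.integer = 10
```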
@@ -1384,6 +1384,8 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
transformCommonInlineUserDefinedFunction(exp.getCommonInlineUserDefinedFunction)
case proto.Expression.ExprTypeCase.CALL_FUNCTION =>
transformCallFunction(exp.getCallFunction)
case proto.Expression.ExprTypeCase.NAMED_ARGUMENT_EXPRESSION =>
transformNamedArgumentExpression(exp.getNamedArgumentExpression)
case _ =>
throw InvalidPlanInput(
s"Expression with ID: ${exp.getExprTypeCase.getNumber} is not supported")
@@ -1505,6 +1507,11 @@
false)
}

private def transformNamedArgumentExpression(
namedArg: proto.NamedArgumentExpression): Expression = {
NamedArgumentExpression(namedArg.getKey, transformExpression(namedArg.getValue))
}

private def unpackUdf(fun: proto.CommonInlineUserDefinedFunction): UdfPacket = {
unpackScalarScalaUDF[UdfPacket](fun.getScalarScalaUdf)
}
14 changes: 14 additions & 0 deletions python/pyspark/sql/column.py
@@ -73,9 +73,23 @@ def _to_java_expr(col: "ColumnOrName") -> JavaObject:
return _to_java_column(col).expr()


@overload
def _to_seq(sc: SparkContext, cols: Iterable[JavaObject]) -> JavaObject:
pass


@overload
def _to_seq(
sc: SparkContext,
cols: Iterable["ColumnOrName"],
converter: Optional[Callable[["ColumnOrName"], JavaObject]],
) -> JavaObject:
pass


def _to_seq(
sc: SparkContext,
cols: Union[Iterable["ColumnOrName"], Iterable[JavaObject]],
converter: Optional[Callable[["ColumnOrName"], JavaObject]] = None,
) -> JavaObject:
"""
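The two overloads codify the two call shapes this internal helper already accepts; a rough sketch of both (assuming a classic-mode session `spark` and a DataFrame `df`, both hypothetical):

```py
# Rough sketch of the two overloaded call shapes of the internal _to_seq
# helper (classic mode; `spark` and `df` are assumed to exist).
from pyspark.sql.column import _to_java_column, _to_seq

sc = spark.sparkContext
# ColumnOrName values plus an explicit converter...
jcols = _to_seq(sc, [df.a, "b"], _to_java_column)
# ...or already-converted JavaObject columns with no converter.
jcols2 = _to_seq(sc, [_to_java_column(df.a), _to_java_column("b")])
```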
20 changes: 20 additions & 0 deletions python/pyspark/sql/connect/expressions.py
@@ -1054,3 +1054,23 @@ def __repr__(self) -> str:
return f"CallFunction('{self._name}', {', '.join([str(arg) for arg in self._args])})"
else:
return f"CallFunction('{self._name}')"


class NamedArgumentExpression(Expression):
def __init__(self, key: str, value: Expression):
super().__init__()

assert isinstance(key, str)
self._key = key

assert isinstance(value, Expression)
self._value = value

def to_plan(self, session: "SparkConnectClient") -> "proto.Expression":
expr = proto.Expression()
expr.named_argument_expression.key = self._key
expr.named_argument_expression.value.CopyFrom(self._value.to_plan(session))
return expr

def __repr__(self) -> str:
return f"{self._key} => {self._value}"
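A quick sketch of the new client-side expression (Spark Connect client assumed; `LiteralExpression` is the existing literal wrapper in the same module, and the value here is illustrative):

```py
# Sketch: the repr of the new expression mirrors SQL's `key => value` syntax
# (Spark Connect client assumed; the literal is illustrative).
from pyspark.sql.connect.expressions import LiteralExpression, NamedArgumentExpression
from pyspark.sql.types import IntegerType

arg = NamedArgumentExpression("a", LiteralExpression(10, IntegerType()))
print(arg)  # prints the key/value pair, e.g. a => 10
```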
122 changes: 62 additions & 60 deletions python/pyspark/sql/connect/proto/expressions_pb2.py

Large diffs are not rendered by default.

34 changes: 34 additions & 0 deletions python/pyspark/sql/connect/proto/expressions_pb2.pyi
@@ -1102,6 +1102,7 @@ class Expression(google.protobuf.message.Message):
UNRESOLVED_NAMED_LAMBDA_VARIABLE_FIELD_NUMBER: builtins.int
COMMON_INLINE_USER_DEFINED_FUNCTION_FIELD_NUMBER: builtins.int
CALL_FUNCTION_FIELD_NUMBER: builtins.int
NAMED_ARGUMENT_EXPRESSION_FIELD_NUMBER: builtins.int
EXTENSION_FIELD_NUMBER: builtins.int
@property
def literal(self) -> global___Expression.Literal: ...
@@ -1138,6 +1139,8 @@
@property
def call_function(self) -> global___CallFunction: ...
@property
def named_argument_expression(self) -> global___NamedArgumentExpression: ...
@property
def extension(self) -> google.protobuf.any_pb2.Any:
"""This field is used to mark extensions to the protocol. When plugins generate arbitrary
relations they can add them here. During the planning the correct resolution is done.
@@ -1162,6 +1165,7 @@
| None = ...,
common_inline_user_defined_function: global___CommonInlineUserDefinedFunction | None = ...,
call_function: global___CallFunction | None = ...,
named_argument_expression: global___NamedArgumentExpression | None = ...,
extension: google.protobuf.any_pb2.Any | None = ...,
) -> None: ...
def HasField(
@@ -1185,6 +1189,8 @@
b"lambda_function",
"literal",
b"literal",
"named_argument_expression",
b"named_argument_expression",
"sort_order",
b"sort_order",
"unresolved_attribute",
@@ -1226,6 +1232,8 @@
b"lambda_function",
"literal",
b"literal",
"named_argument_expression",
b"named_argument_expression",
"sort_order",
b"sort_order",
"unresolved_attribute",
@@ -1265,6 +1273,7 @@
"unresolved_named_lambda_variable",
"common_inline_user_defined_function",
"call_function",
"named_argument_expression",
"extension",
] | None: ...

@@ -1505,3 +1514,28 @@
) -> None: ...

global___CallFunction = CallFunction

class NamedArgumentExpression(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor

KEY_FIELD_NUMBER: builtins.int
VALUE_FIELD_NUMBER: builtins.int
key: builtins.str
"""(Required) The key of the named argument."""
@property
def value(self) -> global___Expression:
"""(Required) The value expression of the named argument."""
def __init__(
self,
*,
key: builtins.str = ...,
value: global___Expression | None = ...,
) -> None: ...
def HasField(
self, field_name: typing_extensions.Literal["value", b"value"]
) -> builtins.bool: ...
def ClearField(
self, field_name: typing_extensions.Literal["key", b"key", "value", b"value"]
) -> None: ...

global___NamedArgumentExpression = NamedArgumentExpression
18 changes: 10 additions & 8 deletions python/pyspark/sql/connect/udtf.py
@@ -22,11 +22,11 @@
check_dependencies(__name__)

import warnings
-from typing import Type, TYPE_CHECKING, Optional, Union
+from typing import List, Type, TYPE_CHECKING, Optional, Union

from pyspark.rdd import PythonEvalType
from pyspark.sql.connect.column import Column
-from pyspark.sql.connect.expressions import ColumnReference
+from pyspark.sql.connect.expressions import ColumnReference, Expression, NamedArgumentExpression
from pyspark.sql.connect.plan import (
CommonInlineUserDefinedTableFunction,
PythonUDTF,
@@ -146,12 +146,14 @@ def __init__(
self.deterministic = deterministic

    def _build_common_inline_user_defined_table_function(
-        self, *cols: "ColumnOrName"
+        self, *args: "ColumnOrName", **kwargs: "ColumnOrName"
    ) -> CommonInlineUserDefinedTableFunction:
-        arg_cols = [
-            col if isinstance(col, Column) else Column(ColumnReference(col)) for col in cols
+        def to_expr(col: "ColumnOrName") -> Expression:
+            return col._expr if isinstance(col, Column) else ColumnReference(col)
+
+        arg_exprs: List[Expression] = [to_expr(arg) for arg in args] + [
+            NamedArgumentExpression(key, to_expr(value)) for key, value in kwargs.items()
         ]
-        arg_exprs = [col._expr for col in arg_cols]

udtf = PythonUDTF(
func=self.func,
Expand All @@ -166,13 +168,13 @@ def _build_common_inline_user_defined_table_function(
arguments=arg_exprs,
)

-    def __call__(self, *cols: "ColumnOrName") -> "DataFrame":
+    def __call__(self, *args: "ColumnOrName", **kwargs: "ColumnOrName") -> "DataFrame":
from pyspark.sql.connect.session import SparkSession
from pyspark.sql.connect.dataframe import DataFrame

session = SparkSession.active()

-        plan = self._build_common_inline_user_defined_table_function(*cols)
+        plan = self._build_common_inline_user_defined_table_function(*args, **kwargs)
return DataFrame.withPlan(plan, session)

def asNondeterministic(self) -> "UserDefinedTableFunction":
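For illustration, the mapping above turns a Connect-side call such as `TestUDTF(lit(10), b=lit("x"))` into one positional expression plus one named-argument expression — a minimal sketch (Connect client assumed; the columns are illustrative):

```py
# Minimal sketch of the arg/kwarg-to-expression mapping performed by
# _build_common_inline_user_defined_table_function (Connect client assumed).
from pyspark.sql.connect.expressions import NamedArgumentExpression
from pyspark.sql.connect.functions import lit

args, kwargs = (lit(10),), {"b": lit("x")}
arg_exprs = [arg._expr for arg in args] + [
    NamedArgumentExpression(key, value._expr) for key, value in kwargs.items()
]
# arg_exprs now holds the positional literal followed by the named argument b => x.
```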
38 changes: 38 additions & 0 deletions python/pyspark/sql/functions.py
@@ -15547,6 +15547,12 @@ def udtf(
.. versionadded:: 3.5.0

.. versionchanged:: 4.0.0
    Supports Python side analysis.

.. versionchanged:: 4.0.0
    Supports keyword-arguments.

Parameters
----------
cls : class
@@ -15623,6 +15629,38 @@
| 1| x|
+---+---+
UDTF can use keyword arguments:
>>> @udtf
... class TestUDTFWithKwargs:
... @staticmethod
... def analyze(
... a: AnalyzeArgument, b: AnalyzeArgument, **kwargs: AnalyzeArgument
... ) -> AnalyzeResult:
... return AnalyzeResult(
... StructType().add("a", a.data_type)
... .add("b", b.data_type)
... .add("x", kwargs["x"].data_type)
... )
...
... def eval(self, a, b, **kwargs):
... yield a, b, kwargs["x"]
...
>>> TestUDTFWithKwargs(lit(1), x=lit("x"), b=lit("b")).show()
+---+---+---+
| a| b| x|
+---+---+---+
| 1| b| x|
+---+---+---+
>>> _ = spark.udtf.register("test_udtf", TestUDTFWithKwargs)
>>> spark.sql("SELECT * FROM test_udtf(1, x=>'x', b=>'b')").show()
+---+---+---+
| a| b| x|
+---+---+---+
| 1| b| x|
+---+---+---+
Arrow optimization can be explicitly enabled when creating UDTFs:
>>> @udtf(returnType="c1: int, c2: int", useArrow=True)
88 changes: 88 additions & 0 deletions python/pyspark/sql/tests/test_udtf.py
@@ -1565,6 +1565,7 @@ def eval(self, **kwargs):
expected = [Row(c1="hello", c2="world")]
assertDataFrameEqual(TestUDTF(), expected)
assertDataFrameEqual(self.spark.sql("SELECT * FROM test_udtf()"), expected)
assertDataFrameEqual(self.spark.sql("SELECT * FROM test_udtf(a=>1)"), expected)

with self.assertRaisesRegex(
AnalysisException, r"analyze\(\) takes 0 positional arguments but 1 was given"
@@ -1795,6 +1796,93 @@ def terminate(self):
assertSchemaEqual(df.schema, StructType().add("col1", IntegerType()))
assertDataFrameEqual(df, [Row(col1=10), Row(col1=100)])

def test_udtf_with_named_arguments(self):
@udtf(returnType="a: int")
class TestUDTF:
def eval(self, a, b):
yield a,

self.spark.udtf.register("test_udtf", TestUDTF)

for i, df in enumerate(
[
self.spark.sql("SELECT * FROM test_udtf(a=>10, b=>'x')"),
self.spark.sql("SELECT * FROM test_udtf(b=>'x', a=>10)"),
TestUDTF(a=lit(10), b=lit("x")),
TestUDTF(b=lit("x"), a=lit(10)),
]
):
with self.subTest(query_no=i):
assertDataFrameEqual(df, [Row(a=10)])

def test_udtf_with_named_arguments_negative(self):
@udtf(returnType="a: int")
class TestUDTF:
def eval(self, a, b):
yield a,

self.spark.udtf.register("test_udtf", TestUDTF)

with self.assertRaisesRegex(
AnalysisException,
"DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE",
):
self.spark.sql("SELECT * FROM test_udtf(a=>10, a=>100)").show()

with self.assertRaisesRegex(AnalysisException, "UNEXPECTED_POSITIONAL_ARGUMENT"):
self.spark.sql("SELECT * FROM test_udtf(a=>10, 'x')").show()

with self.assertRaisesRegex(
PythonException, r"eval\(\) got an unexpected keyword argument 'c'"
):
self.spark.sql("SELECT * FROM test_udtf(c=>'x')").show()

def test_udtf_with_kwargs(self):
@udtf(returnType="a: int, b: string")
class TestUDTF:
def eval(self, **kwargs):
yield kwargs["a"], kwargs["b"]

self.spark.udtf.register("test_udtf", TestUDTF)

for i, df in enumerate(
[
self.spark.sql("SELECT * FROM test_udtf(a=>10, b=>'x')"),
self.spark.sql("SELECT * FROM test_udtf(b=>'x', a=>10)"),
TestUDTF(a=lit(10), b=lit("x")),
TestUDTF(b=lit("x"), a=lit(10)),
]
):
with self.subTest(query_no=i):
assertDataFrameEqual(df, [Row(a=10, b="x")])

def test_udtf_with_analyze_kwargs(self):
@udtf
class TestUDTF:
@staticmethod
def analyze(**kwargs: AnalyzeArgument) -> AnalyzeResult:
return AnalyzeResult(
StructType(
[StructField(key, arg.data_type) for key, arg in sorted(kwargs.items())]
)
)

def eval(self, **kwargs):
yield tuple(value for _, value in sorted(kwargs.items()))

self.spark.udtf.register("test_udtf", TestUDTF)

for i, df in enumerate(
[
self.spark.sql("SELECT * FROM test_udtf(a=>10, b=>'x')"),
self.spark.sql("SELECT * FROM test_udtf(b=>'x', a=>10)"),
TestUDTF(a=lit(10), b=lit("x")),
TestUDTF(b=lit("x"), a=lit(10)),
]
):
with self.subTest(query_no=i):
assertDataFrameEqual(df, [Row(a=10, b="x")])


class UDTFTests(BaseUDTFTestsMixin, ReusedSQLTestCase):
@classmethod