
[SPARK-50856][SS][PYTHON][CONNECT] Spark Connect Support for TransformWithStateInPandas In Python #49560

Open · wants to merge 16 commits into base: master
1 change: 1 addition & 0 deletions dev/sparktestsupport/modules.py
@@ -1096,6 +1096,7 @@ def __hash__(self):
"pyspark.sql.tests.connect.pandas.test_parity_pandas_udf_scalar",
"pyspark.sql.tests.connect.pandas.test_parity_pandas_udf_grouped_agg",
"pyspark.sql.tests.connect.pandas.test_parity_pandas_udf_window",
"pyspark.sql.tests.connect.pandas.test_parity_pandas_transform_with_state",
Contributor

Have we verified that this test actually runs in CI?

Contributor Author
@jingz-db jingz-db Feb 1, 2025

Yeah, I think so. Several test cases from this suite failed in a previous CI run (https://github.com/jingz-db/spark/actions/runs/13039529632/job/36378113583#step:12:4144). Those failures are now fixed, but they confirm the suite actually runs in CI.

    ],
    excluded_python_implementations=[
        "PyPy"  # Skip these tests under PyPy since they require numpy, pandas, and pyarrow and
58 changes: 58 additions & 0 deletions python/pyspark/sql/connect/group.py
@@ -41,6 +41,7 @@
from pyspark.sql.column import Column
from pyspark.sql.connect.functions import builtin as F
from pyspark.errors import PySparkNotImplementedError, PySparkTypeError
from pyspark.sql.streaming.stateful_processor import StatefulProcessor

if TYPE_CHECKING:
    from pyspark.sql.connect._typing import (
@@ -361,6 +362,63 @@ def applyInPandasWithState(

    applyInPandasWithState.__doc__ = PySparkGroupedData.applyInPandasWithState.__doc__

    def transformWithStateInPandas(
        self,
        statefulProcessor: StatefulProcessor,
        outputStructType: Union[StructType, str],
        outputMode: str,
        timeMode: str,
        initialState: Optional["GroupedData"] = None,
        eventTimeColumnName: str = "",
    ) -> "DataFrame":
        from pyspark.sql.connect.udf import UserDefinedFunction
        from pyspark.sql.connect.dataframe import DataFrame
        from pyspark.sql.streaming.stateful_processor_util import TransformWithStateInPandasUdfUtils

        udf_util = TransformWithStateInPandasUdfUtils(statefulProcessor, timeMode)
        if initialState is None:
            udf_obj = UserDefinedFunction(
                udf_util.transformWithStateUDF,  # type: ignore
                returnType=outputStructType,
                evalType=PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF,
            )
            initial_state_plan = None
            initial_state_grouping_cols = None
        else:
            self._df._check_same_session(initialState._df)
            udf_obj = UserDefinedFunction(
                udf_util.transformWithStateWithInitStateUDF,  # type: ignore
                returnType=outputStructType,
                evalType=PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF,
            )
            initial_state_plan = initialState._df._plan
            initial_state_grouping_cols = initialState._grouping_cols

        output_schema: str = (
            outputStructType.json()
            if isinstance(outputStructType, StructType)
            else outputStructType
        )

        return DataFrame(
            plan.TransformWithStateInPandas(
                child=self._df._plan,
                grouping_cols=self._grouping_cols,
                function=udf_obj,
                output_schema=output_schema,
                output_mode=outputMode,
                time_mode=timeMode,
                event_time_col_name=eventTimeColumnName,
                cols=self._df.columns,
                initial_state_plan=initial_state_plan,
                initial_state_grouping_cols=initial_state_grouping_cols,
            ),
            session=self._df._session,
        )

    transformWithStateInPandas.__doc__ = PySparkGroupedData.transformWithStateInPandas.__doc__

    def applyInArrow(
        self, func: "ArrowGroupedMapFunction", schema: Union[StructType, str]
    ) -> "DataFrame":
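For reviewers, a minimal sketch of how this new Connect API is driven from user code. CountProcessor, the state schema, and df are illustrative assumptions, not part of this PR; the sketch follows the existing StatefulProcessor interface (init / handleInputRows / close):

import pandas as pd
from typing import Iterator
from pyspark.sql.streaming.stateful_processor import StatefulProcessor, StatefulProcessorHandle
from pyspark.sql.types import StructType, StructField, LongType

class CountProcessor(StatefulProcessor):
    # Hypothetical processor: keeps a running row count per grouping key.
    def init(self, handle: StatefulProcessorHandle) -> None:
        schema = StructType([StructField("value", LongType(), True)])
        self._count = handle.getValueState("count", schema)

    def handleInputRows(self, key, rows, timer_values) -> Iterator[pd.DataFrame]:
        count = self._count.get()[0] if self._count.exists() else 0
        for batch in rows:  # `rows` is an iterator of pandas DataFrames
            count += len(batch)
        self._count.update((count,))
        yield pd.DataFrame({"id": [key[0]], "count": [count]})

    def close(self) -> None:
        pass

# `df` is an assumed streaming DataFrame with an "id" column.
result = df.groupBy("id").transformWithStateInPandas(
    statefulProcessor=CountProcessor(),
    outputStructType="id STRING, count LONG",
    outputMode="Update",
    timeMode="None",
)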
64 changes: 64 additions & 0 deletions python/pyspark/sql/connect/plan.py
@@ -2546,6 +2546,70 @@ def plan(self, session: "SparkConnectClient") -> proto.Relation:
        return self._with_relations(plan, session)


class TransformWithStateInPandas(LogicalPlan):
    """Logical plan object for a TransformWithStateInPandas."""

    def __init__(
        self,
        child: Optional["LogicalPlan"],
        grouping_cols: Sequence[Column],
        function: "UserDefinedFunction",
        output_schema: str,
        output_mode: str,
        time_mode: str,
        event_time_col_name: str,
        cols: List[str],
        initial_state_plan: Optional["LogicalPlan"],
        initial_state_grouping_cols: Optional[Sequence[Column]],
    ):
        assert isinstance(grouping_cols, list) and all(isinstance(c, Column) for c in grouping_cols)
        if initial_state_plan is not None:
            assert isinstance(initial_state_grouping_cols, list) and all(
                isinstance(c, Column) for c in initial_state_grouping_cols
            )
            super().__init__(
                child, self._collect_references(grouping_cols + initial_state_grouping_cols)
            )
        else:
            super().__init__(child, self._collect_references(grouping_cols))
        self._grouping_cols = grouping_cols
        self._output_schema = output_schema
        self._output_mode = output_mode
        self._time_mode = time_mode
        self._event_time_col_name = event_time_col_name
        self._function = function._build_common_inline_user_defined_function(*cols)
        self._initial_state_plan = initial_state_plan
        self._initial_state_grouping_cols = initial_state_grouping_cols

    def plan(self, session: "SparkConnectClient") -> proto.Relation:
        assert self._child is not None
        plan = self._create_proto_relation()
        plan.transform_with_state_in_pandas.input.CopyFrom(self._child.plan(session))
        plan.transform_with_state_in_pandas.grouping_expressions.extend(
            [c.to_plan(session) for c in self._grouping_cols]
        )
        # fill in initial state related fields
        if self._initial_state_plan is not None:
            self._initial_state_plan = cast(LogicalPlan, self._initial_state_plan)
            plan.transform_with_state_in_pandas.initial_input.CopyFrom(
                self._initial_state_plan.plan(session)
            )
            plan.transform_with_state_in_pandas.initial_grouping_expressions.extend(
                [c.to_plan(session) for c in self._initial_state_grouping_cols]
            )

        plan.transform_with_state_in_pandas.output_schema = self._output_schema
        plan.transform_with_state_in_pandas.output_mode = self._output_mode
        plan.transform_with_state_in_pandas.time_mode = self._time_mode
        plan.transform_with_state_in_pandas.event_time_col_name = self._event_time_col_name
        # wrap transformWithStateInPandasUdf in a function
        plan.transform_with_state_in_pandas.transform_with_state_udf.CopyFrom(
            self._function.to_plan_udf(session)
        )

        return self._with_relations(plan, session)


class PythonUDTF:
    """Represents a Python user-defined table function."""

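To make the initial-state branch above concrete: passing an initialState GroupedData switches the UDF to the init-state eval type and fills initial_input / initial_grouping_expressions in the proto. A hedged sketch (init_df is illustrative, CountProcessor is the sketch from the group.py section above, and a processor used this way is also expected to implement handleInitialState):

# Assumes `spark` is a Connect session and `df` is the streaming DataFrame above.
init_df = spark.createDataFrame([("a", 5)], "id STRING, count LONG")

result = df.groupBy("id").transformWithStateInPandas(
    statefulProcessor=CountProcessor(),
    outputStructType="id STRING, count LONG",
    outputMode="Update",
    timeMode="None",
    initialState=init_df.groupBy("id"),
)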
334 changes: 168 additions & 166 deletions python/pyspark/sql/connect/proto/relations_pb2.py

Large diffs are not rendered by default.

111 changes: 111 additions & 0 deletions python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -108,6 +108,7 @@ class Relation(google.protobuf.message.Message):
    TRANSPOSE_FIELD_NUMBER: builtins.int
    UNRESOLVED_TABLE_VALUED_FUNCTION_FIELD_NUMBER: builtins.int
    LATERAL_JOIN_FIELD_NUMBER: builtins.int
    TRANSFORM_WITH_STATE_IN_PANDAS_FIELD_NUMBER: builtins.int
    FILL_NA_FIELD_NUMBER: builtins.int
    DROP_NA_FIELD_NUMBER: builtins.int
    REPLACE_FIELD_NUMBER: builtins.int
@@ -216,6 +217,8 @@
    @property
    def lateral_join(self) -> global___LateralJoin: ...
    @property
    def transform_with_state_in_pandas(self) -> global___TransformWithStateInPandas: ...
    @property
    def fill_na(self) -> global___NAFill:
        """NA functions"""
    @property
@@ -301,6 +304,7 @@
        transpose: global___Transpose | None = ...,
        unresolved_table_valued_function: global___UnresolvedTableValuedFunction | None = ...,
        lateral_join: global___LateralJoin | None = ...,
        transform_with_state_in_pandas: global___TransformWithStateInPandas | None = ...,
        fill_na: global___NAFill | None = ...,
        drop_na: global___NADrop | None = ...,
        replace: global___NAReplace | None = ...,
@@ -424,6 +428,8 @@
            b"to_df",
            "to_schema",
            b"to_schema",
            "transform_with_state_in_pandas",
            b"transform_with_state_in_pandas",
            "transpose",
            b"transpose",
            "unknown",
@@ -549,6 +555,8 @@
            b"to_df",
            "to_schema",
            b"to_schema",
            "transform_with_state_in_pandas",
            b"transform_with_state_in_pandas",
            "transpose",
            b"transpose",
            "unknown",
@@ -614,6 +622,7 @@
            "transpose",
            "unresolved_table_valued_function",
            "lateral_join",
            "transform_with_state_in_pandas",
            "fill_na",
            "drop_na",
            "replace",
@@ -3930,6 +3939,108 @@ class ApplyInPandasWithState(google.protobuf.message.Message):

global___ApplyInPandasWithState = ApplyInPandasWithState

class TransformWithStateInPandas(google.protobuf.message.Message):
    DESCRIPTOR: google.protobuf.descriptor.Descriptor

    INPUT_FIELD_NUMBER: builtins.int
    GROUPING_EXPRESSIONS_FIELD_NUMBER: builtins.int
    TRANSFORM_WITH_STATE_UDF_FIELD_NUMBER: builtins.int
    OUTPUT_SCHEMA_FIELD_NUMBER: builtins.int
    OUTPUT_MODE_FIELD_NUMBER: builtins.int
    TIME_MODE_FIELD_NUMBER: builtins.int
    INITIAL_INPUT_FIELD_NUMBER: builtins.int
    INITIAL_GROUPING_EXPRESSIONS_FIELD_NUMBER: builtins.int
    EVENT_TIME_COL_NAME_FIELD_NUMBER: builtins.int
    @property
    def input(self) -> global___Relation:
        """(Required) Input relation for transformWithStateInPandas."""
    @property
    def grouping_expressions(
        self,
    ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
        pyspark.sql.connect.proto.expressions_pb2.Expression
    ]:
        """(Required) Expressions for grouping keys."""
    @property
    def transform_with_state_udf(
        self,
    ) -> pyspark.sql.connect.proto.expressions_pb2.CommonInlineUserDefinedFunction:
        """(Required) Serialized bytes of the user-defined stateful processor."""
    output_schema: builtins.str
    """(Required) Schema for the output DataFrame."""
    output_mode: builtins.str
    """(Required) The output mode of the function."""
    time_mode: builtins.str
    """(Required) Time mode for transformWithStateInPandas."""
    @property
    def initial_input(self) -> global___Relation:
        """(Optional) Input relation for the initial state."""
    @property
    def initial_grouping_expressions(
        self,
    ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
        pyspark.sql.connect.proto.expressions_pb2.Expression
    ]:
        """(Optional) Expressions for grouping keys of the initial state input relation."""
    event_time_col_name: builtins.str
    """(Required) Event time column name. Defaults to the empty string."""
    def __init__(
        self,
        *,
        input: global___Relation | None = ...,
        grouping_expressions: collections.abc.Iterable[
            pyspark.sql.connect.proto.expressions_pb2.Expression
        ]
        | None = ...,
        transform_with_state_udf: pyspark.sql.connect.proto.expressions_pb2.CommonInlineUserDefinedFunction
        | None = ...,
        output_schema: builtins.str = ...,
        output_mode: builtins.str = ...,
        time_mode: builtins.str = ...,
        initial_input: global___Relation | None = ...,
        initial_grouping_expressions: collections.abc.Iterable[
            pyspark.sql.connect.proto.expressions_pb2.Expression
        ]
        | None = ...,
        event_time_col_name: builtins.str = ...,
    ) -> None: ...
    def HasField(
        self,
        field_name: typing_extensions.Literal[
            "initial_input",
            b"initial_input",
            "input",
            b"input",
            "transform_with_state_udf",
            b"transform_with_state_udf",
        ],
    ) -> builtins.bool: ...
    def ClearField(
        self,
        field_name: typing_extensions.Literal[
            "event_time_col_name",
            b"event_time_col_name",
            "grouping_expressions",
            b"grouping_expressions",
            "initial_grouping_expressions",
            b"initial_grouping_expressions",
            "initial_input",
            b"initial_input",
            "input",
            b"input",
            "output_mode",
            b"output_mode",
            "output_schema",
            b"output_schema",
            "time_mode",
            b"time_mode",
            "transform_with_state_udf",
            b"transform_with_state_udf",
        ],
    ) -> None: ...

global___TransformWithStateInPandas = TransformWithStateInPandas

class CommonInlineUserDefinedTableFunction(google.protobuf.message.Message):
    DESCRIPTOR: google.protobuf.descriptor.Descriptor

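As a rough illustration of the wire format these stubs describe (the field values below are placeholders, not what the planner actually sends), the generated message behaves like any protobuf message, and presence of singular submessages such as initial_input can be checked with HasField:

from pyspark.sql.connect.proto import relations_pb2

# Placeholder values; a real plan embeds the child relation and the
# serialized stateful processor via CopyFrom, as in plan.py above.
rel = relations_pb2.TransformWithStateInPandas(
    output_schema="id STRING, count LONG",
    output_mode="Update",
    time_mode="None",
    event_time_col_name="",
)
print(rel.HasField("initial_input"))  # False: no initial state attached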