Skip to content

Commit

Permalink
[SPARK-48454][PS] Directly use the parent dataframe class
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
Directly use the parent dataframe class

### Why are the changes needed?
Code cleanup.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
CI

### Was this patch authored or co-authored using generative AI tooling?
No

Closes apache#46785 from zhengruifeng/ps_df_cleanup.

Authored-by: Ruifeng Zheng <ruifengz@apache.org>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
  • Loading branch information
zhengruifeng authored and HyukjinKwon committed May 29, 2024
1 parent 935f092 commit 47c55f4
Show file tree
Hide file tree
Showing 6 changed files with 9 additions and 33 deletions.
7 changes: 2 additions & 5 deletions python/pyspark/pandas/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,6 @@
create_tuple_for_frame_type,
)
from pyspark.pandas.plot import PandasOnSparkPlotAccessor
from pyspark.sql.utils import get_dataframe_class

if TYPE_CHECKING:
from pyspark.sql._typing import OptionalPrimitiveType
Expand Down Expand Up @@ -529,15 +528,14 @@ class DataFrame(Frame, Generic[T]):
def __init__( # type: ignore[no-untyped-def]
self, data=None, index=None, columns=None, dtype=None, copy=False
):
SparkDataFrame = get_dataframe_class()
index_assigned = False
if isinstance(data, InternalFrame):
assert columns is None
assert dtype is None
assert not copy
if index is None:
internal = data
elif isinstance(data, SparkDataFrame):
elif isinstance(data, PySparkDataFrame):
assert columns is None
assert dtype is None
assert not copy
Expand Down Expand Up @@ -13730,8 +13728,7 @@ def _reduce_spark_multi(sdf: PySparkDataFrame, aggs: List[PySparkColumn]) -> Any
"""
Performs a reduction on a spark DataFrame, the functions being known SQL aggregate functions.
"""
SparkDataFrame = get_dataframe_class()
assert isinstance(sdf, SparkDataFrame)
assert isinstance(sdf, PySparkDataFrame)
sdf0 = sdf.agg(*aggs)
lst = sdf0.limit(2).toPandas()
assert len(lst) == 1, (sdf, lst)
Expand Down
6 changes: 2 additions & 4 deletions python/pyspark/pandas/internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,7 @@
StructType,
StringType,
)
from pyspark.sql.utils import is_timestamp_ntz_preferred
from pyspark.sql.utils import is_remote, get_dataframe_class
from pyspark.sql.utils import is_timestamp_ntz_preferred, is_remote
from pyspark import pandas as ps
from pyspark.pandas._typing import Label
from pyspark.pandas.spark.utils import as_nullable_spark_type, force_decimal_precision_scale
Expand Down Expand Up @@ -620,8 +619,7 @@ def __init__(
>>> internal.column_label_names
[('column_labels_a',), ('column_labels_b',)]
"""
SparkDataFrame = get_dataframe_class()
assert isinstance(spark_frame, SparkDataFrame)
assert isinstance(spark_frame, PySparkDataFrame)
assert not spark_frame.isStreaming, "pandas-on-Spark does not support Structured Streaming."

if not index_spark_columns:
Expand Down
4 changes: 1 addition & 3 deletions python/pyspark/pandas/spark/accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
from pyspark.sql.types import DataType, StructType
from pyspark.pandas._typing import IndexOpsLike
from pyspark.pandas.internal import InternalField
from pyspark.sql.utils import get_dataframe_class

if TYPE_CHECKING:
from pyspark.sql._typing import OptionalPrimitiveType
Expand Down Expand Up @@ -936,8 +935,7 @@ def apply(
2 3 1
"""
output = func(self.frame(index_col))
SparkDataFrame = get_dataframe_class()
if not isinstance(output, SparkDataFrame):
if not isinstance(output, PySparkDataFrame):
raise ValueError(
"The output of the function [%s] should be of a "
"pyspark.sql.DataFrame; however, got [%s]." % (func, type(output))
Expand Down
5 changes: 2 additions & 3 deletions python/pyspark/pandas/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@

from pyspark.sql import functions as F, Column, DataFrame as PySparkDataFrame, SparkSession
from pyspark.sql.types import DoubleType
from pyspark.sql.utils import is_remote, get_dataframe_class
from pyspark.sql.utils import is_remote
from pyspark.errors import PySparkTypeError
from pyspark import pandas as ps # noqa: F401
from pyspark.pandas._typing import (
Expand Down Expand Up @@ -915,8 +915,7 @@ def verify_temp_column_name(
)
column_name = column_name_or_label

SparkDataFrame = get_dataframe_class()
assert isinstance(df, SparkDataFrame), type(df)
assert isinstance(df, PySparkDataFrame), type(df)
assert (
column_name not in df.columns
), "The given column name `{}` already exists in the Spark DataFrame: {}".format(
Expand Down
11 changes: 0 additions & 11 deletions python/pyspark/sql/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,17 +422,6 @@ def pyspark_column_op(
return result.fillna(fillna) if fillna is not None else result


def get_dataframe_class() -> Type["DataFrame"]:
from pyspark.sql.dataframe import DataFrame as PySparkDataFrame

if is_remote():
from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame

return ConnectDataFrame
else:
return PySparkDataFrame


def get_window_class() -> Type["Window"]:
from pyspark.sql.window import Window as PySparkWindow

Expand Down
9 changes: 2 additions & 7 deletions python/pyspark/testing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -829,12 +829,7 @@ def assertDataFrameEqual(
actual, expected, almost=True, rtol=rtol, atol=atol, check_row_order=checkRowOrder
)

from pyspark.sql.utils import get_dataframe_class

# if is_remote(), allow Connect DataFrame
SparkDataFrame = get_dataframe_class()

if not isinstance(actual, (DataFrame, SparkDataFrame, list)):
if not isinstance(actual, (DataFrame, list)):
raise PySparkAssertionError(
error_class="INVALID_TYPE_DF_EQUALITY_ARG",
message_parameters={
Expand All @@ -843,7 +838,7 @@ def assertDataFrameEqual(
"actual_type": type(actual),
},
)
elif not isinstance(expected, (DataFrame, SparkDataFrame, list)):
elif not isinstance(expected, (DataFrame, list)):
raise PySparkAssertionError(
error_class="INVALID_TYPE_DF_EQUALITY_ARG",
message_parameters={
Expand Down

0 comments on commit 47c55f4

Please sign in to comment.