From 47c55f4c9a4432ea09b7250a2cae6c6b1700f023 Mon Sep 17 00:00:00 2001
From: Ruifeng Zheng
Date: Thu, 30 May 2024 08:25:38 +0900
Subject: [PATCH] [SPARK-48454][PS] Directly use the parent dataframe class

### What changes were proposed in this pull request?
Directly use the parent dataframe class

### Why are the changes needed?
code clean up

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
CI

### Was this patch authored or co-authored using generative AI tooling?
No

Closes #46785 from zhengruifeng/ps_df_cleanup.

Authored-by: Ruifeng Zheng
Signed-off-by: Hyukjin Kwon
---
 python/pyspark/pandas/frame.py           |  7 ++-----
 python/pyspark/pandas/internal.py        |  6 ++----
 python/pyspark/pandas/spark/accessors.py |  4 +---
 python/pyspark/pandas/utils.py           |  5 ++---
 python/pyspark/sql/utils.py              | 11 -----------
 python/pyspark/testing/utils.py          |  9 ++-------
 6 files changed, 9 insertions(+), 33 deletions(-)

diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py
index 1a69e88a81d04..52f7a327b5be0 100644
--- a/python/pyspark/pandas/frame.py
+++ b/python/pyspark/pandas/frame.py
@@ -149,7 +149,6 @@
     create_tuple_for_frame_type,
 )
 from pyspark.pandas.plot import PandasOnSparkPlotAccessor
-from pyspark.sql.utils import get_dataframe_class
 
 if TYPE_CHECKING:
     from pyspark.sql._typing import OptionalPrimitiveType
@@ -529,7 +528,6 @@ class DataFrame(Frame, Generic[T]):
     def __init__(  # type: ignore[no-untyped-def]
         self, data=None, index=None, columns=None, dtype=None, copy=False
     ):
-        SparkDataFrame = get_dataframe_class()
         index_assigned = False
         if isinstance(data, InternalFrame):
             assert columns is None
@@ -537,7 +535,7 @@ def __init__(  # type: ignore[no-untyped-def]
             assert not copy
             if index is None:
                 internal = data
-        elif isinstance(data, SparkDataFrame):
+        elif isinstance(data, PySparkDataFrame):
             assert columns is None
             assert dtype is None
             assert not copy
@@ -13730,8 +13728,7 @@ def _reduce_spark_multi(sdf: PySparkDataFrame, aggs: List[PySparkColumn]) -> Any
     """
     Performs a reduction on a spark DataFrame, the functions being known SQL aggregate functions.
     """
-    SparkDataFrame = get_dataframe_class()
-    assert isinstance(sdf, SparkDataFrame)
+    assert isinstance(sdf, PySparkDataFrame)
     sdf0 = sdf.agg(*aggs)
     lst = sdf0.limit(2).toPandas()
     assert len(lst) == 1, (sdf, lst)
diff --git a/python/pyspark/pandas/internal.py b/python/pyspark/pandas/internal.py
index 04285aa2d879d..c5fef3b138254 100644
--- a/python/pyspark/pandas/internal.py
+++ b/python/pyspark/pandas/internal.py
@@ -41,8 +41,7 @@
     StructType,
     StringType,
 )
-from pyspark.sql.utils import is_timestamp_ntz_preferred
-from pyspark.sql.utils import is_remote, get_dataframe_class
+from pyspark.sql.utils import is_timestamp_ntz_preferred, is_remote
 from pyspark import pandas as ps
 from pyspark.pandas._typing import Label
 from pyspark.pandas.spark.utils import as_nullable_spark_type, force_decimal_precision_scale
@@ -620,8 +619,7 @@ def __init__(
         >>> internal.column_label_names
         [('column_labels_a',), ('column_labels_b',)]
         """
-        SparkDataFrame = get_dataframe_class()
-        assert isinstance(spark_frame, SparkDataFrame)
+        assert isinstance(spark_frame, PySparkDataFrame)
         assert not spark_frame.isStreaming, "pandas-on-Spark does not support Structured Streaming."
 
         if not index_spark_columns:
diff --git a/python/pyspark/pandas/spark/accessors.py b/python/pyspark/pandas/spark/accessors.py
index b73d24b12d9da..7f3041cf79c7c 100644
--- a/python/pyspark/pandas/spark/accessors.py
+++ b/python/pyspark/pandas/spark/accessors.py
@@ -27,7 +27,6 @@
 from pyspark.sql.types import DataType, StructType
 from pyspark.pandas._typing import IndexOpsLike
 from pyspark.pandas.internal import InternalField
-from pyspark.sql.utils import get_dataframe_class
 
 if TYPE_CHECKING:
     from pyspark.sql._typing import OptionalPrimitiveType
@@ -936,8 +935,7 @@ def apply(
         2  3  1
         """
         output = func(self.frame(index_col))
-        SparkDataFrame = get_dataframe_class()
-        if not isinstance(output, SparkDataFrame):
+        if not isinstance(output, PySparkDataFrame):
             raise ValueError(
                 "The output of the function [%s] should be of a "
                 "pyspark.sql.DataFrame; however, got [%s]." % (func, type(output))
diff --git a/python/pyspark/pandas/utils.py b/python/pyspark/pandas/utils.py
index 0fe2944bcabe9..fec45072cf93a 100644
--- a/python/pyspark/pandas/utils.py
+++ b/python/pyspark/pandas/utils.py
@@ -42,7 +42,7 @@
 
 from pyspark.sql import functions as F, Column, DataFrame as PySparkDataFrame, SparkSession
 from pyspark.sql.types import DoubleType
-from pyspark.sql.utils import is_remote, get_dataframe_class
+from pyspark.sql.utils import is_remote
 from pyspark.errors import PySparkTypeError
 from pyspark import pandas as ps  # noqa: F401
 from pyspark.pandas._typing import (
@@ -915,8 +915,7 @@ def verify_temp_column_name(
         )
         column_name = column_name_or_label
 
-    SparkDataFrame = get_dataframe_class()
-    assert isinstance(df, SparkDataFrame), type(df)
+    assert isinstance(df, PySparkDataFrame), type(df)
     assert (
         column_name not in df.columns
     ), "The given column name `{}` already exists in the Spark DataFrame: {}".format(
diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py
index df0451fa1bd2c..98bc7a72c4aa1 100644
--- a/python/pyspark/sql/utils.py
+++ b/python/pyspark/sql/utils.py
@@ -422,17 +422,6 @@ def pyspark_column_op(
     return result.fillna(fillna) if fillna is not None else result
 
 
-def get_dataframe_class() -> Type["DataFrame"]:
-    from pyspark.sql.dataframe import DataFrame as PySparkDataFrame
-
-    if is_remote():
-        from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame
-
-        return ConnectDataFrame
-    else:
-        return PySparkDataFrame
-
-
 def get_window_class() -> Type["Window"]:
     from pyspark.sql.window import Window as PySparkWindow
 
diff --git a/python/pyspark/testing/utils.py b/python/pyspark/testing/utils.py
index 8a7aa405e4ac7..fa58b7286fe88 100644
--- a/python/pyspark/testing/utils.py
+++ b/python/pyspark/testing/utils.py
@@ -829,12 +829,7 @@ def assertDataFrameEqual(
             actual, expected, almost=True, rtol=rtol, atol=atol, check_row_order=checkRowOrder
         )
 
-    from pyspark.sql.utils import get_dataframe_class
-
-    # if is_remote(), allow Connect DataFrame
-    SparkDataFrame = get_dataframe_class()
-
-    if not isinstance(actual, (DataFrame, SparkDataFrame, list)):
+    if not isinstance(actual, (DataFrame, list)):
         raise PySparkAssertionError(
             error_class="INVALID_TYPE_DF_EQUALITY_ARG",
             message_parameters={
@@ -843,7 +838,7 @@ def assertDataFrameEqual(
                 "actual_type": type(actual),
             },
         )
-    elif not isinstance(expected, (DataFrame, SparkDataFrame, list)):
+    elif not isinstance(expected, (DataFrame, list)):
         raise PySparkAssertionError(
             error_class="INVALID_TYPE_DF_EQUALITY_ARG",
             message_parameters={
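Reviewer note: the deleted get_dataframe_class() helper existed only to pick between the Classic and Connect DataFrame classes at runtime via is_remote(). Since the Connect DataFrame subclasses the common parent class pyspark.sql.DataFrame (hence the PR title), a single isinstance check against the parent covers both backends. A minimal runnable sketch of that pattern follows; the toy classes are illustrative stand-ins, not the actual pyspark class layout:

    # Toy model of the class hierarchy this patch relies on.
    class DataFrame:                    # stands in for the parent pyspark.sql.DataFrame
        pass

    class ClassicDataFrame(DataFrame):  # stands in for the JVM-backed implementation
        pass

    class ConnectDataFrame(DataFrame):  # stands in for pyspark.sql.connect.dataframe.DataFrame
        pass

    def verify(df: DataFrame) -> None:
        # One parent-class check replaces the old is_remote() dispatch.
        assert isinstance(df, DataFrame), type(df)

    verify(ClassicDataFrame())
    verify(ConnectDataFrame())          # both backends pass the same check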