From 2c2b68383921e955d1674bc6d9ea373157f60c5b Mon Sep 17 00:00:00 2001
From: Takuya Ueshin
Date: Thu, 27 Jun 2024 16:49:49 -0700
Subject: [PATCH] Fix.

---
 python/pyspark/pandas/internal.py | 263 +++++++++++++-----------------
 1 file changed, 117 insertions(+), 146 deletions(-)

diff --git a/python/pyspark/pandas/internal.py b/python/pyspark/pandas/internal.py
index c5fef3b138254..c856f79ded957 100644
--- a/python/pyspark/pandas/internal.py
+++ b/python/pyspark/pandas/internal.py
@@ -631,39 +631,34 @@ def __init__(
                     ]
                 spark_frame = spark_frame.select(data_spark_columns)
 
-            assert not any(SPARK_INDEX_NAME_PATTERN.match(name) for name in spark_frame.columns), (
+            base_schema = spark_frame.schema
+
+            assert not any(SPARK_INDEX_NAME_PATTERN.match(name) for name in base_schema.names), (
                 "Index columns should not appear in columns of the Spark DataFrame. Avoid "
                 "index column names [%s]." % SPARK_INDEX_NAME_PATTERN
            )
 
             # Create default index.
             spark_frame = InternalFrame.attach_default_index(spark_frame)
+            base_schema = StructType(
+                [StructField(SPARK_DEFAULT_INDEX_NAME, LongType(), nullable=False)]
+                + base_schema.fields,
+            )
             index_spark_columns = [scol_for(spark_frame, SPARK_DEFAULT_INDEX_NAME)]
 
-            index_fields = [
-                InternalField.from_struct_field(
-                    StructField(SPARK_DEFAULT_INDEX_NAME, LongType(), nullable=False)
-                )
-            ]
+            index_fields = [InternalField.from_struct_field(base_schema.fields[0])]
 
             if data_spark_columns is not None:
-                data_struct_fields = [
-                    field
-                    for field in spark_frame.schema.fields
-                    if field.name != SPARK_DEFAULT_INDEX_NAME
-                ]
                 data_spark_columns = [
-                    scol_for(spark_frame, field.name) for field in data_struct_fields
+                    scol_for(spark_frame, field.name) for field in base_schema.fields[1:]
                 ]
-                if data_fields is not None:
-                    data_fields = [
-                        field.copy(
-                            name=name_like_string(struct_field.name),
-                        )
-                        for field, struct_field in zip(data_fields, data_struct_fields)
-                    ]
+                data_fields = [
+                    InternalField.from_struct_field(field) for field in base_schema.fields[1:]
+                ]
+        else:
+            base_schema = spark_frame.schema
 
-        if NATURAL_ORDER_COLUMN_NAME not in spark_frame.columns:
+        if NATURAL_ORDER_COLUMN_NAME not in base_schema.names:
             spark_frame = spark_frame.withColumn(
                 NATURAL_ORDER_COLUMN_NAME, F.monotonically_increasing_id()
             )
@@ -682,7 +677,7 @@ def __init__(
         if data_spark_columns is None:
             data_spark_columns = [
                 scol_for(spark_frame, col)
-                for col in spark_frame.columns
+                for col in base_schema.names
                 if all(
                     not spark_column_equals(scol_for(spark_frame, col), index_scol)
                     for index_scol in index_spark_columns
@@ -695,110 +690,13 @@ def __init__(
         self._data_spark_columns: List[PySparkColumn] = data_spark_columns
 
         # fields
-        if index_fields is None:
-            index_fields = [None] * len(index_spark_columns)
-        if data_fields is None:
-            data_fields = [None] * len(data_spark_columns)
-
-        assert len(index_spark_columns) == len(index_fields), (
-            len(index_spark_columns),
-            len(index_fields),
-        )
-        assert len(data_spark_columns) == len(data_fields), (
-            len(data_spark_columns),
-            len(data_fields),
-        )
-
-        if any(field is None or field.struct_field is None for field in index_fields) and any(
-            field is None or field.struct_field is None for field in data_fields
-        ):
-            schema = spark_frame.select(index_spark_columns + data_spark_columns).schema
-            fields = [
-                InternalField.from_struct_field(struct_field)
-                if field is None
-                else InternalField(field.dtype, struct_field)
-                if field.struct_field is None
-                else field
-                for field, struct_field in zip(index_fields + data_fields, schema.fields)
-            ]
-            index_fields = fields[: len(index_spark_columns)]
-            data_fields = fields[len(index_spark_columns) :]
-        elif any(field is None or field.struct_field is None for field in index_fields):
-            schema = spark_frame.select(index_spark_columns).schema
-            index_fields = [
-                InternalField.from_struct_field(struct_field)
-                if field is None
-                else InternalField(field.dtype, struct_field)
-                if field.struct_field is None
-                else field
-                for field, struct_field in zip(index_fields, schema.fields)
-            ]
-        elif any(field is None or field.struct_field is None for field in data_fields):
-            schema = spark_frame.select(data_spark_columns).schema
-            data_fields = [
-                InternalField.from_struct_field(struct_field)
-                if field is None
-                else InternalField(field.dtype, struct_field)
-                if field.struct_field is None
-                else field
-                for field, struct_field in zip(data_fields, schema.fields)
-            ]
-
-        assert all(
-            isinstance(ops.dtype, Dtype.__args__)  # type: ignore[attr-defined]
-            and (
-                ops.dtype == np.dtype("object")
-                or as_spark_type(ops.dtype, raise_error=False) is not None
-            )
-            for ops in index_fields
-        ), index_fields
-
-        if is_testing():
-            struct_fields = spark_frame.select(index_spark_columns).schema.fields
-            if is_remote():
-                # TODO(SPARK-42965): For some reason, the metadata of StructField is different
-                # in a few tests when using Spark Connect. However, the function works properly.
-                # Therefore, we temporarily perform Spark Connect tests by excluding metadata
-                # until the issue is resolved.
-                assert all(
-                    _drop_metadata(index_field.struct_field) == _drop_metadata(struct_field)
-                    for index_field, struct_field in zip(index_fields, struct_fields)
-                ), (index_fields, struct_fields)
-            else:
-                assert all(
-                    index_field.struct_field == struct_field
-                    for index_field, struct_field in zip(index_fields, struct_fields)
-                ), (index_fields, struct_fields)
-
-        self._index_fields: List[InternalField] = index_fields
-
-        assert all(
-            isinstance(ops.dtype, Dtype.__args__)  # type: ignore[attr-defined]
-            and (
-                ops.dtype == np.dtype("object")
-                or as_spark_type(ops.dtype, raise_error=False) is not None
-            )
-            for ops in data_fields
-        ), data_fields
+        if index_fields is not None:
+            self._check_fields(index_fields, index_spark_columns)
+        self._index_fields: Optional[List[InternalField]] = index_fields
 
-        if is_testing():
-            struct_fields = spark_frame.select(data_spark_columns).schema.fields
-            if is_remote():
-                # TODO(SPARK-42965): For some reason, the metadata of StructField is different
-                # in a few tests when using Spark Connect. However, the function works properly.
-                # Therefore, we temporarily perform Spark Connect tests by excluding metadata
-                # until the issue is resolved.
-                assert all(
-                    _drop_metadata(data_field.struct_field) == _drop_metadata(struct_field)
-                    for data_field, struct_field in zip(data_fields, struct_fields)
-                ), (data_fields, struct_fields)
-            else:
-                assert all(
-                    data_field.struct_field == struct_field
-                    for data_field, struct_field in zip(data_fields, struct_fields)
-                ), (data_fields, struct_fields)
-
-        self._data_fields: List[InternalField] = data_fields
+        if data_fields is not None:
+            self._check_fields(data_fields, data_spark_columns)
+        self._data_fields: Optional[List[InternalField]] = data_fields
 
         # index_names
         if not index_names:
@@ -815,9 +713,7 @@ def __init__(
         self._index_names: List[Optional[Label]] = index_names
 
         # column_labels
-        if column_labels is None:
-            column_labels = [(col,) for col in spark_frame.select(self._data_spark_columns).columns]
-        else:
+        if column_labels is not None:
             assert len(column_labels) == len(self._data_spark_columns), (
                 len(column_labels),
                 len(self._data_spark_columns),
@@ -832,13 +728,11 @@ def __init__(
                 ), column_labels
                 assert len(set(len(label) for label in column_labels)) <= 1, column_labels
 
-        self._column_labels: List[Label] = column_labels
+        self._column_labels: Optional[List[Label]] = column_labels
 
         # column_label_names
-        if column_label_names is None:
-            column_label_names = [None] * column_labels_level(self._column_labels)
-        else:
-            if len(self._column_labels) > 0:
+        if column_label_names is not None:
+            if self._column_labels is not None and len(self._column_labels) > 0:
                 assert len(column_label_names) == column_labels_level(self._column_labels), (
                     len(column_label_names),
                     column_labels_level(self._column_labels),
@@ -850,7 +744,51 @@ def __init__(
                    for column_label_name in column_label_names
                 ), column_label_names
 
-        self._column_label_names: List[Optional[Label]] = column_label_names
+        self._column_label_names: Optional[List[Optional[Label]]] = column_label_names
+
+    def _check_fields(
+        self, fields: List[InternalField], spark_columns: List[PySparkColumn]
+    ) -> None:
+        # The `fields` should have the same length as columns
+        assert len(fields) == len(spark_columns), (
+            len(fields),
+            len(spark_columns),
+        )
+
+        # The `field.dtype` should be one of supported types
+        assert all(
+            field is None
+            or (
+                isinstance(field.dtype, Dtype.__args__)  # type: ignore[attr-defined]
+                and (
+                    field.dtype == np.dtype("object")
+                    or as_spark_type(field.dtype, raise_error=False) is not None
+                )
+            )
+            for field in fields
+        ), fields
+
+        if is_testing():
+            # The `field.struct_field` should be the same as the corresponding column's field
+            struct_fields = self.spark_frame.select(spark_columns).schema.fields
+            if is_remote():
+                # TODO(SPARK-42965): For some reason, the metadata of StructField is different
+                # in a few tests when using Spark Connect. However, the function works properly.
+                # Therefore, we temporarily perform Spark Connect tests by excluding metadata
+                # until the issue is resolved.
+                assert all(
+                    field is None
+                    or field.struct_field is None
+                    or _drop_metadata(field.struct_field) == _drop_metadata(struct_field)
+                    for field, struct_field in zip(fields, struct_fields)
+                ), (fields, struct_fields)
+            else:
+                assert all(
+                    field is None
+                    or field.struct_field is None
+                    or field.struct_field == struct_field
+                    for field, struct_field in zip(fields, struct_fields)
+                ), (fields, struct_fields)
 
     @staticmethod
     def attach_default_index(
@@ -1052,29 +990,62 @@ def index_level(self) -> int:
         """Return the level of the index."""
         return len(self._index_names)
 
-    @property
+    @lazy_property
     def column_labels(self) -> List[Label]:
         """Return the managed column index."""
+        if self._column_labels is None:
+            self._column_labels = [(field.name,) for field in self.data_fields]
         return self._column_labels
 
     @lazy_property
     def column_labels_level(self) -> int:
         """Return the level of the column index."""
-        return len(self._column_label_names)
+        return len(self.column_label_names)
 
-    @property
+    @lazy_property
     def column_label_names(self) -> List[Optional[Label]]:
         """Return names of the index levels."""
+        if self._column_label_names is None:
+            self._column_label_names = [None] * column_labels_level(self.column_labels)
         return self._column_label_names
 
-    @property
+    def _initialize_fields(self):
+        if self._index_fields is None:
+            self._index_fields = [None] * len(self._index_spark_columns)
+        if self._data_fields is None:
+            self._data_fields = [None] * len(self._data_spark_columns)
+
+        if any(
+            field is None or field.struct_field is None
+            for field in self._index_fields + self._data_fields
+        ):
+            schema = self.spark_frame.select(
+                self.index_spark_columns + self.data_spark_columns
+            ).schema
+            fields = [
+                InternalField.from_struct_field(struct_field)
+                if field is None
+                else InternalField(field.dtype, struct_field)
+                if field.struct_field is None
+                else field
+                for field, struct_field in zip(
+                    self._index_fields + self._data_fields, schema.fields
+                )
+            ]
+            self._check_fields(fields, self.index_spark_columns + self.data_spark_columns)
+            self._index_fields = fields[: len(self.index_spark_columns)]
+            self._data_fields = fields[len(self.index_spark_columns) :]
+
+    @lazy_property
     def index_fields(self) -> List[InternalField]:
         """Return InternalFields for the managed index columns."""
+        self._initialize_fields()
         return self._index_fields
 
-    @property
+    @lazy_property
     def data_fields(self) -> List[InternalField]:
         """Return InternalFields for the managed columns."""
+        self._initialize_fields()
         return self._data_fields
 
     @lazy_property
@@ -1450,21 +1421,21 @@ def copy(
         :return: the copied immutable InternalFrame.
         """
         if spark_frame is _NoValue:
-            spark_frame = self.spark_frame
+            spark_frame = self._sdf
         if index_spark_columns is _NoValue:
-            index_spark_columns = self.index_spark_columns
+            index_spark_columns = self._index_spark_columns
         if index_names is _NoValue:
-            index_names = self.index_names
+            index_names = self._index_names
         if index_fields is _NoValue:
-            index_fields = self.index_fields
+            index_fields = self._index_fields
        if column_labels is _NoValue:
-            column_labels = self.column_labels
+            column_labels = self._column_labels
         if data_spark_columns is _NoValue:
-            data_spark_columns = self.data_spark_columns
+            data_spark_columns = self._data_spark_columns
         if data_fields is _NoValue:
-            data_fields = self.data_fields
+            data_fields = self._data_fields
         if column_label_names is _NoValue:
-            column_label_names = self.column_label_names
+            column_label_names = self._column_label_names
         return InternalFrame(
             spark_frame=cast(PySparkDataFrame, spark_frame),
             index_spark_columns=cast(List[PySparkColumn], index_spark_columns),