
Commit

Fix.
ueshin committed Jun 27, 2024
1 parent 5035c20 commit 2c2b683
Showing 1 changed file with 117 additions and 146 deletions.
263 changes: 117 additions & 146 deletions python/pyspark/pandas/internal.py
@@ -631,39 +631,34 @@ def __init__(
]
spark_frame = spark_frame.select(data_spark_columns)

assert not any(SPARK_INDEX_NAME_PATTERN.match(name) for name in spark_frame.columns), (
base_schema = spark_frame.schema

assert not any(SPARK_INDEX_NAME_PATTERN.match(name) for name in base_schema.names), (
"Index columns should not appear in columns of the Spark DataFrame. Avoid "
"index column names [%s]." % SPARK_INDEX_NAME_PATTERN
)

# Create default index.
spark_frame = InternalFrame.attach_default_index(spark_frame)
base_schema = StructType(
[StructField(SPARK_DEFAULT_INDEX_NAME, LongType(), nullable=False)]
+ base_schema.fields,
)
index_spark_columns = [scol_for(spark_frame, SPARK_DEFAULT_INDEX_NAME)]

index_fields = [
InternalField.from_struct_field(
StructField(SPARK_DEFAULT_INDEX_NAME, LongType(), nullable=False)
)
]
index_fields = [InternalField.from_struct_field(base_schema.fields[0])]

if data_spark_columns is not None:
data_struct_fields = [
field
for field in spark_frame.schema.fields
if field.name != SPARK_DEFAULT_INDEX_NAME
]
data_spark_columns = [
scol_for(spark_frame, field.name) for field in data_struct_fields
scol_for(spark_frame, field.name) for field in base_schema.fields[1:]
]
if data_fields is not None:
data_fields = [
field.copy(
name=name_like_string(struct_field.name),
)
for field, struct_field in zip(data_fields, data_struct_fields)
]
data_fields = [
InternalField.from_struct_field(field) for field in base_schema.fields[1:]
]
else:
base_schema = spark_frame.schema

if NATURAL_ORDER_COLUMN_NAME not in spark_frame.columns:
if NATURAL_ORDER_COLUMN_NAME not in base_schema.names:
spark_frame = spark_frame.withColumn(
NATURAL_ORDER_COLUMN_NAME, F.monotonically_increasing_id()
)
@@ -682,7 +677,7 @@ def __init__(
if data_spark_columns is None:
data_spark_columns = [
scol_for(spark_frame, col)
for col in spark_frame.columns
for col in base_schema.names
if all(
not spark_column_equals(scol_for(spark_frame, col), index_scol)
for index_scol in index_spark_columns
@@ -695,110 +690,13 @@ def __init__(
self._data_spark_columns: List[PySparkColumn] = data_spark_columns

# fields
if index_fields is None:
index_fields = [None] * len(index_spark_columns)
if data_fields is None:
data_fields = [None] * len(data_spark_columns)

assert len(index_spark_columns) == len(index_fields), (
len(index_spark_columns),
len(index_fields),
)
assert len(data_spark_columns) == len(data_fields), (
len(data_spark_columns),
len(data_fields),
)

if any(field is None or field.struct_field is None for field in index_fields) and any(
field is None or field.struct_field is None for field in data_fields
):
schema = spark_frame.select(index_spark_columns + data_spark_columns).schema
fields = [
InternalField.from_struct_field(struct_field)
if field is None
else InternalField(field.dtype, struct_field)
if field.struct_field is None
else field
for field, struct_field in zip(index_fields + data_fields, schema.fields)
]
index_fields = fields[: len(index_spark_columns)]
data_fields = fields[len(index_spark_columns) :]
elif any(field is None or field.struct_field is None for field in index_fields):
schema = spark_frame.select(index_spark_columns).schema
index_fields = [
InternalField.from_struct_field(struct_field)
if field is None
else InternalField(field.dtype, struct_field)
if field.struct_field is None
else field
for field, struct_field in zip(index_fields, schema.fields)
]
elif any(field is None or field.struct_field is None for field in data_fields):
schema = spark_frame.select(data_spark_columns).schema
data_fields = [
InternalField.from_struct_field(struct_field)
if field is None
else InternalField(field.dtype, struct_field)
if field.struct_field is None
else field
for field, struct_field in zip(data_fields, schema.fields)
]

assert all(
isinstance(ops.dtype, Dtype.__args__) # type: ignore[attr-defined]
and (
ops.dtype == np.dtype("object")
or as_spark_type(ops.dtype, raise_error=False) is not None
)
for ops in index_fields
), index_fields

if is_testing():
struct_fields = spark_frame.select(index_spark_columns).schema.fields
if is_remote():
# TODO(SPARK-42965): For some reason, the metadata of StructField is different
# in a few tests when using Spark Connect. However, the function works properly.
# Therefore, we temporarily perform Spark Connect tests by excluding metadata
# until the issue is resolved.
assert all(
_drop_metadata(index_field.struct_field) == _drop_metadata(struct_field)
for index_field, struct_field in zip(index_fields, struct_fields)
), (index_fields, struct_fields)
else:
assert all(
index_field.struct_field == struct_field
for index_field, struct_field in zip(index_fields, struct_fields)
), (index_fields, struct_fields)

self._index_fields: List[InternalField] = index_fields

assert all(
isinstance(ops.dtype, Dtype.__args__) # type: ignore[attr-defined]
and (
ops.dtype == np.dtype("object")
or as_spark_type(ops.dtype, raise_error=False) is not None
)
for ops in data_fields
), data_fields
if index_fields is not None:
self._check_fields(index_fields, index_spark_columns)
self._index_fields: Optional[List[InternalField]] = index_fields

if is_testing():
struct_fields = spark_frame.select(data_spark_columns).schema.fields
if is_remote():
# TODO(SPARK-42965): For some reason, the metadata of StructField is different
# in a few tests when using Spark Connect. However, the function works properly.
# Therefore, we temporarily perform Spark Connect tests by excluding metadata
# until the issue is resolved.
assert all(
_drop_metadata(data_field.struct_field) == _drop_metadata(struct_field)
for data_field, struct_field in zip(data_fields, struct_fields)
), (data_fields, struct_fields)
else:
assert all(
data_field.struct_field == struct_field
for data_field, struct_field in zip(data_fields, struct_fields)
), (data_fields, struct_fields)

self._data_fields: List[InternalField] = data_fields
if data_fields is not None:
self._check_fields(data_fields, data_spark_columns)
self._data_fields: Optional[List[InternalField]] = data_fields

# index_names
if not index_names:
@@ -815,9 +713,7 @@ def __init__(
self._index_names: List[Optional[Label]] = index_names

# column_labels
if column_labels is None:
column_labels = [(col,) for col in spark_frame.select(self._data_spark_columns).columns]
else:
if column_labels is not None:
assert len(column_labels) == len(self._data_spark_columns), (
len(column_labels),
len(self._data_spark_columns),
@@ -832,13 +728,11 @@ def __init__(
), column_labels
assert len(set(len(label) for label in column_labels)) <= 1, column_labels

self._column_labels: List[Label] = column_labels
self._column_labels: Optional[List[Label]] = column_labels

# column_label_names
if column_label_names is None:
column_label_names = [None] * column_labels_level(self._column_labels)
else:
if len(self._column_labels) > 0:
if column_label_names is not None:
if self._column_labels is not None and len(self._column_labels) > 0:
assert len(column_label_names) == column_labels_level(self._column_labels), (
len(column_label_names),
column_labels_level(self._column_labels),
@@ -850,7 +744,51 @@ def __init__(
for column_label_name in column_label_names
), column_label_names

self._column_label_names: List[Optional[Label]] = column_label_names
self._column_label_names: Optional[List[Optional[Label]]] = column_label_names

def _check_fields(
self, fields: List[InternalField], spark_columns: List[PySparkColumn]
) -> None:
# The `fields` should have the same length as columns
assert len(fields) == len(spark_columns), (
len(fields),
len(spark_columns),
)

# The `field.dtype` should be one of supported types
assert all(
field is None
or (
isinstance(field.dtype, Dtype.__args__) # type: ignore[attr-defined]
and (
field.dtype == np.dtype("object")
or as_spark_type(field.dtype, raise_error=False) is not None
)
)
for field in fields
), fields

if is_testing():
# The `field.struct_field` should be the same as the corresponding column's field
struct_fields = self.spark_frame.select(spark_columns).schema.fields
if is_remote():
# TODO(SPARK-42965): For some reason, the metadata of StructField is different
# in a few tests when using Spark Connect. However, the function works properly.
# Therefore, we temporarily perform Spark Connect tests by excluding metadata
# until the issue is resolved.
assert all(
field is None
or field.struct_field is None
or _drop_metadata(field.struct_field) == _drop_metadata(struct_field)
for field, struct_field in zip(fields, struct_fields)
), (fields, struct_fields)
else:
assert all(
field is None
or field.struct_field is None
or field.struct_field == struct_field
for field, struct_field in zip(fields, struct_fields)
), (fields, struct_fields)

@staticmethod
def attach_default_index(
@@ -1052,29 +990,62 @@ def index_level(self) -> int:
"""Return the level of the index."""
return len(self._index_names)

@property
@lazy_property
def column_labels(self) -> List[Label]:
"""Return the managed column index."""
if self._column_labels is None:
self._column_labels = [(field.name,) for field in self.data_fields]
return self._column_labels

@lazy_property
def column_labels_level(self) -> int:
"""Return the level of the column index."""
return len(self._column_label_names)
return len(self.column_label_names)

@property
@lazy_property
def column_label_names(self) -> List[Optional[Label]]:
"""Return names of the index levels."""
if self._column_label_names is None:
self._column_label_names = [None] * column_labels_level(self.column_labels)
return self._column_label_names

@property
def _initialize_fields(self):
if self._index_fields is None:
self._index_fields = [None] * len(self._index_spark_columns)
if self._data_fields is None:
self._data_fields = [None] * len(self._data_spark_columns)

if any(
field is None or field.struct_field is None
for field in self._index_fields + self._data_fields
):
schema = self.spark_frame.select(
self.index_spark_columns + self.data_spark_columns
).schema
fields = [
InternalField.from_struct_field(struct_field)
if field is None
else InternalField(field.dtype, struct_field)
if field.struct_field is None
else field
for field, struct_field in zip(
self._index_fields + self._data_fields, schema.fields
)
]
self._check_fields(fields, self.index_spark_columns + self.data_spark_columns)
self._index_fields = fields[: len(self.index_spark_columns)]
self._data_fields = fields[len(self.index_spark_columns) :]

@lazy_property
def index_fields(self) -> List[InternalField]:
"""Return InternalFields for the managed index columns."""
self._initialize_fields()
return self._index_fields

@property
@lazy_property
def data_fields(self) -> List[InternalField]:
"""Return InternalFields for the managed columns."""
self._initialize_fields()
return self._data_fields

@lazy_property
@@ -1450,21 +1421,21 @@ def copy(
:return: the copied immutable InternalFrame.
"""
if spark_frame is _NoValue:
spark_frame = self.spark_frame
spark_frame = self._sdf
if index_spark_columns is _NoValue:
index_spark_columns = self.index_spark_columns
index_spark_columns = self._index_spark_columns
if index_names is _NoValue:
index_names = self.index_names
index_names = self._index_names
if index_fields is _NoValue:
index_fields = self.index_fields
index_fields = self._index_fields
if column_labels is _NoValue:
column_labels = self.column_labels
column_labels = self._column_labels
if data_spark_columns is _NoValue:
data_spark_columns = self.data_spark_columns
data_spark_columns = self._data_spark_columns
if data_fields is _NoValue:
data_fields = self.data_fields
data_fields = self._data_fields
if column_label_names is _NoValue:
column_label_names = self.column_label_names
column_label_names = self._column_label_names
return InternalFrame(
spark_frame=cast(PySparkDataFrame, spark_frame),
index_spark_columns=cast(List[PySparkColumn], index_spark_columns),
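Judging from the diff, this change moves field derivation and validation out of __init__ and behind lazily evaluated properties (column_labels, column_label_names, index_fields, data_fields), with a shared _check_fields helper replacing the previously duplicated assertion blocks, so the schema-inspection cost is only paid when those values are actually needed. Below is a minimal, self-contained sketch of that deferred-computation pattern; it is an illustration only, using functools.cached_property as a stand-in for the module's own lazy_property decorator, and the Frame, _labels, and labels names are hypothetical rather than part of pyspark.

from functools import cached_property
from typing import List, Optional, Tuple


class Frame:
    """Toy container that defers label derivation until first access."""

    def __init__(self, names: List[str], labels: Optional[List[Tuple[str, ...]]] = None) -> None:
        # Only store what the caller provided; nothing is derived or checked here.
        self._names = names
        self._labels = labels

    @cached_property
    def labels(self) -> List[Tuple[str, ...]]:
        # Derived and validated on first access, then cached by cached_property.
        labels = self._labels if self._labels is not None else [(name,) for name in self._names]
        assert len(labels) == len(self._names), (len(labels), len(self._names))
        return labels


frame = Frame(["a", "b"])
print(frame.labels)  # [('a',), ('b',)] -- computed lazily, then reused

In the same spirit, copy() in the diff now forwards the private attributes (self._column_labels and friends) instead of the public properties, which avoids forcing the lazy computation just to make a copy.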
