[data] Handle nullable fields in schema across blocks for parquet files #48478

Merged · 9 commits · Nov 14, 2024
77 changes: 73 additions & 4 deletions python/ray/data/_internal/datasource/parquet_datasink.py
@@ -1,6 +1,6 @@
 import logging
 import posixpath
-from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Optional
+from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional
 
 from ray.data._internal.execution.interfaces import TaskContext
 from ray.data._internal.util import call_with_retry
@@ -75,10 +75,12 @@ def write(

         def write_blocks_to_path():
             with self.open_output_stream(write_path) as file:
-                schema = BlockAccessor.for_block(blocks[0]).to_arrow().schema
+                tables = [BlockAccessor.for_block(block).to_arrow() for block in blocks]
+                schema = self._try_merge_nullable_fields(tables)
Member commented:

Rather than introducing a new method, we could extend the existing unify_schemas function:

                 with pq.ParquetWriter(file, schema, **write_kwargs) as writer:
-                    for block in blocks:
-                        table = BlockAccessor.for_block(block).to_arrow()
+                    for table in tables:
+                        if not table.schema.equals(schema):
+                            table = table.cast(schema)
Member commented:

What happens if we don't explicitly cast the tables?

Contributor Author commented:

The table would still have a mismatched schema, i.e. table.schema.equals(schema) would still return false.

Member commented:

Gotcha. Wasn't sure if PyArrow would implicitly cast tables to match the specified schema under the hood.

Contributor Author commented:

Yeah, it doesn't do the casting, because there's a check on schema equality here:

https://github.com/apache/arrow/blob/main/python/pyarrow/parquet/core.py#L1110-L1114

                         writer.write_table(table)
 
         logger.debug(f"Writing {write_path} file.")
@@ -93,3 +95,70 @@ def write_blocks_to_path():
    @property
    def num_rows_per_write(self) -> Optional[int]:
        return self.num_rows_per_file

    def _try_merge_nullable_fields(
        self, tables: List["pyarrow.Table"]
    ) -> "pyarrow.lib.Schema":
        """Merge the schemas of the Arrow tables built from multiple blocks.

        If the blocks' schemas differ only in that a field is all-null in
        some blocks (and therefore has the null type) while it is a typed,
        nullable field in others, produce a "relaxed" schema that is
        compatible with all of them.

        NOTE: This function only merges null-typed fields into nullable
        fields; it does not reconcile any other schema difference.

        Raises:
            ValueError: If the schemas differ on anything other than
                nullable fields.

        Returns:
            The merged schema.
        """
        import pyarrow

        merged_schema = tables[0].schema

        for table in tables[1:]:
            table_schema = table.schema
            if merged_schema.equals(table_schema):
                continue

            # Schema mismatch found. If the fields only differ by nullable
            # status, we can continue.
            n_merged_schema = len(merged_schema.names)
            n_table_schema = len(table_schema.names)
            if n_merged_schema != n_table_schema:
                raise ValueError(
                    f"Schema mismatch found: {merged_schema} vs {table_schema}"
                )

            for field_idx in range(n_merged_schema):
                field = merged_schema.field(field_idx)
                table_field = table_schema.field(field_idx)

                if field.equals(table_field):
                    continue

                if field.name != table_field.name:
                    raise ValueError(
                        f"Schema mismatch found: {merged_schema} vs {table_schema}"
                    )

                # Check if the fields only differ by nullable status.
                if field.type == pyarrow.null() and table_field.nullable:
                    merged_schema = merged_schema.set(
                        field_idx, field.with_type(table_field.type)
                    )

                if table_field.type == pyarrow.null() and field.nullable:
                    # Make the table schema nullable on the field.
                    table_schema = table_schema.set(
                        field_idx, table_field.with_type(field.type)
                    )

            # This makes sure we only merged on nullable fields.
            if not merged_schema.equals(table_schema):
                raise ValueError(
                    f"Schema mismatch found: {merged_schema} vs {table_schema}"
                )

        return merged_schema