-
Notifications
You must be signed in to change notification settings - Fork 5.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[data] Handle nullable fields in schema across blocks for parquet files #48478
Changes from 3 commits
ae1c6c7
a948bc7
bc2dd0d
a086929
4e7418b
2086c95
d041687
d35d53a
a54e1c9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
import logging | ||
import posixpath | ||
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Optional | ||
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional | ||
|
||
from ray.data._internal.execution.interfaces import TaskContext | ||
from ray.data._internal.util import call_with_retry | ||
|
@@ -75,10 +75,12 @@ def write( | |
|
||
def write_blocks_to_path():
    with self.open_output_stream(write_path) as file:
        # Materialize every block as an Arrow table up front so a single
        # unified schema can be derived before the writer is created.
        tables = [BlockAccessor.for_block(b).to_arrow() for b in blocks]
        schema = self._try_merge_nullable_fields(tables)
        with pq.ParquetWriter(file, schema, **write_kwargs) as writer:
            for table in tables:
                # ParquetWriter requires each table's schema to equal the
                # writer's schema; PyArrow does not cast implicitly, so do
                # the cast explicitly for any table that diverges.
                if not table.schema.equals(schema):
                    table = table.cast(schema)
                writer.write_table(table)
|
||
logger.debug(f"Writing {write_path} file.") | ||
|
@@ -93,3 +95,70 @@ def write_blocks_to_path(): | |
@property | ||
def num_rows_per_write(self) -> Optional[int]: | ||
return self.num_rows_per_file | ||
|
||
def _try_merge_nullable_fields(
    self, tables: List["pyarrow.Table"]
) -> "pyarrow.lib.Schema":
    """Merge the schemas of tables from multiple blocks, relaxing nullability.

    If the blocks' schemas differ only by null-typed columns (a column that
    is entirely null in one block is typed ``null()``) or by a field's
    ``nullable`` flag, produce a single "relaxed" schema that every table
    can be cast to.

    NOTE: this function only reconciles nullability differences; any other
    schema difference is an error.

    Args:
        tables: Arrow tables converted from the blocks being written.

    Raises:
        ValueError: If ``tables`` is empty, or if the schemas differ on
            anything other than nullability (field count, names, or
            incompatible types).

    Returns:
        The merged schema.
    """
    import pyarrow

    if not tables:
        raise ValueError("Cannot merge schemas of an empty list of tables.")

    merged_schema = tables[0].schema

    for table in tables[1:]:
        table_schema = table.schema
        if merged_schema.equals(table_schema):
            continue

        # Schema mismatch found. If fields only differ by nullability,
        # we can reconcile; any structural difference is fatal.
        n_fields = len(merged_schema.names)
        if n_fields != len(table_schema.names):
            raise ValueError(
                f"Schema mismatch found: {merged_schema} vs {table_schema}"
            )

        for field_idx in range(n_fields):
            field = merged_schema.field(field_idx)
            table_field = table_schema.field(field_idx)

            if field.equals(table_field):
                continue

            if field.name != table_field.name:
                raise ValueError(
                    f"Schema mismatch found: {merged_schema} vs {table_schema}"
                )

            if field.type == pyarrow.null() and table_field.nullable:
                # Merged side is an all-null column; adopt the concrete
                # (nullable) type from the other table.
                merged_schema = merged_schema.set(
                    field_idx, field.with_type(table_field.type)
                )
            elif table_field.type == pyarrow.null() and field.nullable:
                # Table side is an all-null column; rewrite its schema with
                # the merged field's concrete type so the final check passes.
                table_schema = table_schema.set(
                    field_idx, table_field.with_type(field.type)
                )
            elif field.type == table_field.type:
                # Same type, fields differ only in the nullable flag: relax
                # both sides to nullable (casting non-nullable -> nullable
                # is always safe).
                merged_schema = merged_schema.set(
                    field_idx, field.with_nullable(True)
                )
                table_schema = table_schema.set(
                    field_idx, table_field.with_nullable(True)
                )

        # This makes sure we only merged on nullability; anything still
        # unequal differs by more than nullable fields.
        if not merged_schema.equals(table_schema):
            raise ValueError(
                f"Schema mismatch found: {merged_schema} vs {table_schema}"
            )

    return merged_schema
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Rather than introducing a new method, we could extend the existing
`unify_schemas` function:
ray/python/ray/data/_internal/arrow_ops/transform_pyarrow.py,
line 54 (commit 23cc23b)