[data] Handle nullable fields in schema across blocks for parquet files (#48478)

## Why are these changes needed?

When writing blocks to Parquet, some blocks may have fields that differ ONLY in nullability. By default such a write is rejected, because those blocks carry a schema that does not match the one the ParquetWriter was opened with. We can allow the write by adjusting the schema.

This PR inspects all blocks before writing them to Parquet and merges schemas that differ only in the nullability of their fields. It also casts each table to the newly merged schema so that the write succeeds.
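
For context, a minimal PyArrow sketch of the behavior this relies on (field names here are made up for illustration): `pa.unify_schemas` merges fields that differ only in nullability, and casting each block to the unified schema makes it acceptable to a single ParquetWriter.

```python
import pyarrow as pa

# Two block schemas that are identical except for the nullability of "b".
s1 = pa.schema([pa.field("a", pa.int64()), pa.field("b", pa.int64(), nullable=False)])
s2 = pa.schema([pa.field("a", pa.int64()), pa.field("b", pa.int64(), nullable=True)])

# unify_schemas() resolves the difference; the merged field "b" is nullable.
merged = pa.unify_schemas([s1, s2])

# Casting a block to the merged schema lets one ParquetWriter accept all blocks.
t1 = pa.table({"a": [1], "b": [2]}, schema=s1)
t1 = t1.cast(merged)
```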


## Related issue number

Closes #48102

---------

Signed-off-by: rickyx <rickyx@anyscale.com>
rickyyx authored Nov 14, 2024
1 parent bcee207 commit 138e59a
Showing 2 changed files with 33 additions and 5 deletions.
14 changes: 9 additions & 5 deletions python/ray/data/_internal/datasource/parquet_datasink.py
@@ -58,6 +58,7 @@ def write(
         blocks: Iterable[Block],
         ctx: TaskContext,
     ) -> None:
+        import pyarrow as pa
         import pyarrow.parquet as pq
 
         blocks = list(blocks)
@@ -72,16 +73,19 @@ def write(
         write_kwargs = _resolve_kwargs(
             self.arrow_parquet_args_fn, **self.arrow_parquet_args
         )
-        schema = write_kwargs.pop("schema", None)
-        if schema is None:
-            schema = BlockAccessor.for_block(blocks[0]).to_arrow().schema
+        user_schema = write_kwargs.pop("schema", None)
 
         def write_blocks_to_path():
             with self.open_output_stream(write_path) as file:
                 tables = [BlockAccessor.for_block(block).to_arrow() for block in blocks]
-                with pq.ParquetWriter(file, schema, **write_kwargs) as writer:
+                if user_schema is None:
+                    output_schema = pa.unify_schemas([table.schema for table in tables])
+                else:
+                    output_schema = user_schema
+
+                with pq.ParquetWriter(file, output_schema, **write_kwargs) as writer:
                     for table in tables:
-                        table = table.cast(schema)
+                        table = table.cast(output_schema)
                         writer.write_table(table)
 
         logger.debug(f"Writing {write_path} file.")
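
For reference, a hedged sketch of how the two code paths above look from the user side (the output paths are placeholders): an explicit `schema` passed through `write_parquet`'s Parquet arguments is used verbatim as the writer schema, while omitting it lets the datasink unify the block schemas.

```python
import pyarrow as pa
import ray

ds = ray.data.from_items([{"a": 1, "b": None}, {"a": 1, "b": 2}])

# Explicit schema: used as-is for the ParquetWriter.
explicit = pa.schema([("a", pa.int64()), ("b", pa.int64())])
ds.write_parquet("/tmp/out_explicit", schema=explicit)

# No schema: the unified schema is inferred from the blocks being written.
ds.write_parquet("/tmp/out_inferred")
```
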
24 changes: 24 additions & 0 deletions python/ray/data/tests/test_parquet.py
@@ -1331,6 +1331,30 @@ def test_write_with_schema(ray_start_regular_shared, tmp_path):
     assert pq.read_table(tmp_path).schema == schema
 
 
+@pytest.mark.parametrize(
+    "row_data",
+    [
+        [{"a": 1, "b": None}, {"a": 1, "b": 2}],
+        [{"a": None, "b": 3}, {"a": 1, "b": 2}],
+        [{"a": None, "b": 1}, {"a": 1, "b": None}],
+    ],
+    ids=["row1_b_null", "row1_a_null", "row_each_null"],
+)
+def test_write_auto_infer_nullable_fields(
+    tmp_path, ray_start_regular_shared, row_data, restore_data_context
+):
+    """
+    Test that when writing multiple blocks, we can automatically infer nullable
+    fields.
+    """
+    ctx = DataContext.get_current()
+    # So that we force multiple blocks on mapping.
+    ctx.target_max_block_size = 1
+    ds = ray.data.range(len(row_data)).map(lambda row: row_data[row["id"]])
+    # So we force writing to a single file.
+    ds.write_parquet(tmp_path, num_rows_per_file=2)
+
+
 if __name__ == "__main__":
     import sys
 
