diff --git a/python/ray/data/_internal/datasource/parquet_datasink.py b/python/ray/data/_internal/datasource/parquet_datasink.py
index 2b8edc11d531..0dfcebf4ba32 100644
--- a/python/ray/data/_internal/datasource/parquet_datasink.py
+++ b/python/ray/data/_internal/datasource/parquet_datasink.py
@@ -58,6 +58,7 @@ def write(
         blocks: Iterable[Block],
         ctx: TaskContext,
     ) -> None:
+        import pyarrow as pa
         import pyarrow.parquet as pq

         blocks = list(blocks)
@@ -72,16 +73,19 @@ def write(
         write_kwargs = _resolve_kwargs(
             self.arrow_parquet_args_fn, **self.arrow_parquet_args
         )
-        schema = write_kwargs.pop("schema", None)
-        if schema is None:
-            schema = BlockAccessor.for_block(blocks[0]).to_arrow().schema
+        user_schema = write_kwargs.pop("schema", None)

         def write_blocks_to_path():
             with self.open_output_stream(write_path) as file:
                 tables = [BlockAccessor.for_block(block).to_arrow() for block in blocks]
-                with pq.ParquetWriter(file, schema, **write_kwargs) as writer:
+                if user_schema is None:
+                    output_schema = pa.unify_schemas([table.schema for table in tables])
+                else:
+                    output_schema = user_schema
+
+                with pq.ParquetWriter(file, output_schema, **write_kwargs) as writer:
                     for table in tables:
-                        table = table.cast(schema)
+                        table = table.cast(output_schema)
                         writer.write_table(table)

         logger.debug(f"Writing {write_path} file.")
diff --git a/python/ray/data/tests/test_parquet.py b/python/ray/data/tests/test_parquet.py
index 739edb1ddd0b..1e493505edf5 100644
--- a/python/ray/data/tests/test_parquet.py
+++ b/python/ray/data/tests/test_parquet.py
@@ -1331,6 +1331,30 @@ def test_write_with_schema(ray_start_regular_shared, tmp_path):
     assert pq.read_table(tmp_path).schema == schema


+@pytest.mark.parametrize(
+    "row_data",
+    [
+        [{"a": 1, "b": None}, {"a": 1, "b": 2}],
+        [{"a": None, "b": 3}, {"a": 1, "b": 2}],
+        [{"a": None, "b": 1}, {"a": 1, "b": None}],
+    ],
+    ids=["row1_b_null", "row1_a_null", "row_each_null"],
+)
+def test_write_auto_infer_nullable_fields(
+    tmp_path, ray_start_regular_shared, row_data, restore_data_context
+):
+    """
+    Test that when writing multiple blocks, we can automatically infer nullable
+    fields.
+    """
+    ctx = DataContext.get_current()
+    # So that we force multiple blocks on mapping.
+    ctx.target_max_block_size = 1
+    ds = ray.data.range(len(row_data)).map(lambda row: row_data[row["id"]])
+    # So we force writing to a single file.
+    ds.write_parquet(tmp_path, num_rows_per_file=2)
+
+
 if __name__ == "__main__":
     import sys
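
For reference, a minimal standalone sketch (not part of the diff) of the mechanism the datasink change relies on: pyarrow.unify_schemas promotes a null-typed column, as inferred for a block whose values are all None, to the concrete type found in another block, and Table.cast then makes every block conform to that unified schema before it is handed to the ParquetWriter. The file path and table contents below are illustrative only.

import pyarrow as pa
import pyarrow.parquet as pq

# Block 1: column "b" is entirely null, so Arrow infers the "null" type for it.
t1 = pa.table({"a": [1], "b": [None]})
# Block 2: column "b" has a concrete integer type.
t2 = pa.table({"a": [1], "b": [2]})

# unify_schemas merges fields by name; with the default options a null-typed
# field unifies with any other type, so "b" resolves to int64 here.
schema = pa.unify_schemas([t1.schema, t2.schema])
assert schema.field("b").type == pa.int64()

# Cast each block to the unified schema before writing, mirroring how the
# datasink casts every table to output_schema. The output path is illustrative.
with pq.ParquetWriter("/tmp/example.parquet", schema) as writer:
    for table in (t1, t2):
        writer.write_table(table.cast(schema))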