PyArrow: Don't enforce the schema when reading/writing (apache#902)
* PyArrow: Don't enforce the schema

PyIceberg struggled with the different Arrow types, such as
`string` and `large_string`. They represent the same data, but are
different under the hood; a short illustration follows below.

My take is that we should hide these kinds of details from the user
as much as possible. Previously we went down the road of passing the
Iceberg schema into Arrow, but when doing this, Iceberg has to
decide whether to use the large or non-large type.

This PR stops passing down the schema and lets Arrow decide, unless:

 - The type should be evolved
 - In case of re-ordering, we re-order the original types
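
As a quick illustration of the distinction: `string` uses 32-bit offsets while
`large_string` uses 64-bit offsets, so the two types compare as unequal even
though they hold the same values.

```python
import pyarrow as pa

small = pa.array(["a", "b"], type=pa.string())        # 32-bit offsets
large = pa.array(["a", "b"], type=pa.large_string())  # 64-bit offsets

assert not small.type.equals(large.type)            # the types differ...
assert small.cast(pa.large_string()).equals(large)  # ...but the data is identical
```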

* WIP

* Reuse Table schema

* Make linter happy

* Squash some bugs

* Thanks Sung!

Co-authored-by: Sung Yun <107272191+syun64@users.noreply.github.com>

* Moar code moar bugs

* Remove the variables wrt file sizes

* Linting

* Go with large ones for now

* Missed one there!

---------

Co-authored-by: Sung Yun <107272191+syun64@users.noreply.github.com>
Fokko and sungwy authored Jul 11, 2024
1 parent 8f47dfd commit 1b9b884
Showing 7 changed files with 156 additions and 58 deletions.
73 changes: 45 additions & 28 deletions pyiceberg/io/pyarrow.py
@@ -1047,8 +1047,10 @@ def _task_to_record_batches(

fragment_scanner = ds.Scanner.from_fragment(
fragment=fragment,
# We always use large types in memory as it uses larger offsets
# That can chunk more row values into the buffers
# With PyArrow 16.0.0 there is an issue with casting record-batches:
# https://github.com/apache/arrow/issues/41884
# https://github.com/apache/arrow/issues/43183
# Would be good to remove this later on
schema=_pyarrow_schema_ensure_large_types(physical_schema),
# This will push down the query to Arrow.
# But in case there are positional deletes, we have to apply them first
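
As an aside, here is a simplified sketch of the kind of schema upcast a helper
like `_pyarrow_schema_ensure_large_types` performs; this is an assumption for
illustration (top-level string/binary fields only), not the project's actual
visitor-based implementation:

```python
import pyarrow as pa

def ensure_large_types_sketch(schema: pa.Schema) -> pa.Schema:
    """Hypothetical helper: map 32-bit-offset types to their large variants."""

    def upcast(dtype: pa.DataType) -> pa.DataType:
        if pa.types.is_string(dtype):
            return pa.large_string()
        if pa.types.is_binary(dtype):
            return pa.large_binary()
        return dtype  # nested types (lists, structs, maps) are out of scope for this sketch

    return pa.schema([field.with_type(upcast(field.type)) for field in schema])
```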
@@ -1084,11 +1086,17 @@ def _task_to_table(
positional_deletes: Optional[List[ChunkedArray]],
case_sensitive: bool,
name_mapping: Optional[NameMapping] = None,
) -> pa.Table:
batches = _task_to_record_batches(
fs, task, bound_row_filter, projected_schema, projected_field_ids, positional_deletes, case_sensitive, name_mapping
) -> Optional[pa.Table]:
batches = list(
_task_to_record_batches(
fs, task, bound_row_filter, projected_schema, projected_field_ids, positional_deletes, case_sensitive, name_mapping
)
)
return pa.Table.from_batches(batches, schema=schema_to_pyarrow(projected_schema, include_field_ids=False))

if len(batches) > 0:
return pa.Table.from_batches(batches)
else:
return None


def _read_all_delete_files(fs: FileSystem, tasks: Iterable[FileScanTask]) -> Dict[str, List[ChunkedArray]]:
@@ -1192,7 +1200,7 @@ def project_table(
if len(tables) < 1:
return pa.Table.from_batches([], schema=schema_to_pyarrow(projected_schema, include_field_ids=False))

result = pa.concat_tables(tables)
result = pa.concat_tables(tables, promote_options="permissive")

if limit is not None:
return result.slice(0, limit)
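
A small example of what `promote_options="permissive"` buys here (assuming
PyArrow >= 14.0, where the keyword exists): tables whose schemas differ only in
`string` vs `large_string` can be concatenated, with the result promoted to the
larger type.

```python
import pyarrow as pa

t1 = pa.table({"foo": pa.array(["a"], type=pa.string())})
t2 = pa.table({"foo": pa.array(["b"], type=pa.large_string())})

# With the default options this raises because the schemas are not identical;
# "permissive" unifies string/large_string into a common (large) type.
combined = pa.concat_tables([t1, t2], promote_options="permissive")
print(combined.schema)  # foo: large_string
```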
@@ -1271,54 +1279,62 @@ def project_batches(


def to_requested_schema(
requested_schema: Schema, file_schema: Schema, batch: pa.RecordBatch, downcast_ns_timestamp_to_us: bool = False
requested_schema: Schema,
file_schema: Schema,
batch: pa.RecordBatch,
downcast_ns_timestamp_to_us: bool = False,
include_field_ids: bool = False,
) -> pa.RecordBatch:
# We could re-use some of these visitors
struct_array = visit_with_partner(
requested_schema, batch, ArrowProjectionVisitor(file_schema, downcast_ns_timestamp_to_us), ArrowAccessor(file_schema)
requested_schema,
batch,
ArrowProjectionVisitor(file_schema, downcast_ns_timestamp_to_us, include_field_ids),
ArrowAccessor(file_schema),
)

arrays = []
fields = []
for pos, field in enumerate(requested_schema.fields):
array = struct_array.field(pos)
arrays.append(array)
fields.append(pa.field(field.name, array.type, field.optional))
return pa.RecordBatch.from_arrays(arrays, schema=pa.schema(fields))
return pa.RecordBatch.from_struct_array(struct_array)
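
For reference, `pa.RecordBatch.from_struct_array` (available in recent PyArrow
releases) turns the struct's child arrays into the batch's columns, so the
visitor only has to produce one correctly typed `StructArray`; a minimal sketch:

```python
import pyarrow as pa

struct = pa.array(
    [{"id": 1, "name": "a"}, {"id": 2, "name": "b"}],
    type=pa.struct([("id", pa.int64()), ("name", pa.large_string())]),
)
batch = pa.RecordBatch.from_struct_array(struct)
print(batch.schema)  # id: int64, name: large_string
```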


class ArrowProjectionVisitor(SchemaWithPartnerVisitor[pa.Array, Optional[pa.Array]]):
file_schema: Schema
_include_field_ids: bool

def __init__(self, file_schema: Schema, downcast_ns_timestamp_to_us: bool = False):
def __init__(self, file_schema: Schema, downcast_ns_timestamp_to_us: bool = False, include_field_ids: bool = False) -> None:
self.file_schema = file_schema
self._include_field_ids = include_field_ids
self.downcast_ns_timestamp_to_us = downcast_ns_timestamp_to_us

def _cast_if_needed(self, field: NestedField, values: pa.Array) -> pa.Array:
file_field = self.file_schema.find_field(field.field_id)

if field.field_type.is_primitive:
if field.field_type != file_field.field_type:
return values.cast(schema_to_pyarrow(promote(file_field.field_type, field.field_type), include_field_ids=False))
elif (target_type := schema_to_pyarrow(field.field_type, include_field_ids=False)) != values.type:
# if file_field and field_type (e.g. String) are the same
# but the pyarrow type of the array is different from the expected type
# (e.g. string vs larger_string), we want to cast the array to the larger type
safe = True
return values.cast(
schema_to_pyarrow(promote(file_field.field_type, field.field_type), include_field_ids=self._include_field_ids)
)
elif (target_type := schema_to_pyarrow(field.field_type, include_field_ids=self._include_field_ids)) != values.type:
# Downcasting of nanoseconds to microseconds
if (
pa.types.is_timestamp(target_type)
and target_type.unit == "us"
and pa.types.is_timestamp(values.type)
and values.type.unit == "ns"
):
safe = False
return values.cast(target_type, safe=safe)
return values.cast(target_type, safe=False)
return values
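
The `safe=False` cast covers the nanosecond-to-microsecond case; a standalone
illustration (values are made up):

```python
import pyarrow as pa

ns = pa.array([1_000_000_001], type=pa.timestamp("ns"))
# A safe cast raises because the trailing nanosecond would be lost;
# safe=False truncates it instead.
us = ns.cast(pa.timestamp("us"), safe=False)
print(us[0])  # 1970-01-01 00:00:01
```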

def _construct_field(self, field: NestedField, arrow_type: pa.DataType) -> pa.Field:
metadata = {}
if field.doc:
metadata[PYARROW_FIELD_DOC_KEY] = field.doc
if self._include_field_ids:
metadata[PYARROW_PARQUET_FIELD_ID_KEY] = str(field.field_id)

return pa.field(
name=field.name,
type=arrow_type,
nullable=field.optional,
metadata={DOC: field.doc} if field.doc is not None else None,
metadata=metadata,
)
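
For illustration, constructing a field with the doc and Parquet field-id
metadata the visitor attaches; the exact metadata keys shown here are
assumptions based on common Arrow/Parquet conventions:

```python
import pyarrow as pa

field = pa.field(
    "foo",
    pa.large_string(),
    nullable=True,
    metadata={"doc": "example column", "PARQUET:field_id": "1"},
)
print(field.metadata)  # {b'doc': b'example column', b'PARQUET:field_id': b'1'}
```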

def schema(self, schema: Schema, schema_partner: Optional[pa.Array], struct_result: Optional[pa.Array]) -> Optional[pa.Array]:
@@ -1960,14 +1976,15 @@ def write_parquet(task: WriteTask) -> DataFile:
file_schema=table_schema,
batch=batch,
downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us,
include_field_ids=True,
)
for batch in task.record_batches
]
arrow_table = pa.Table.from_batches(batches)
file_path = f'{table_metadata.location}/data/{task.generate_data_file_path("parquet")}'
fo = io.new_output(file_path)
with fo.create(overwrite=True) as fos:
with pq.ParquetWriter(fos, schema=file_schema.as_arrow(), **parquet_writer_kwargs) as writer:
with pq.ParquetWriter(fos, schema=arrow_table.schema, **parquet_writer_kwargs) as writer:
writer.write(arrow_table, row_group_size=row_group_size)
statistics = data_file_statistics_from_parquet_metadata(
parquet_metadata=writer.writer.metadata,
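
To make the writer change concrete, a hedged sketch (path and data are
illustrative) of opening a `ParquetWriter` with the Arrow table's own schema,
so field-id metadata attached upstream lands in the Parquet file:

```python
import pyarrow as pa
import pyarrow.parquet as pq

# Field-id metadata attached to the Arrow schema travels into the Parquet footer.
schema = pa.schema([pa.field("foo", pa.large_string(), metadata={"PARQUET:field_id": "1"})])
table = pa.Table.from_arrays([pa.array(["a", "b"], type=pa.large_string())], schema=schema)

with pq.ParquetWriter("/tmp/example.parquet", schema=table.schema) as writer:  # hypothetical path
    writer.write(table)
```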
3 changes: 2 additions & 1 deletion pyiceberg/table/__init__.py
@@ -2053,8 +2053,9 @@ def to_arrow_batch_reader(self) -> pa.RecordBatchReader:

from pyiceberg.io.pyarrow import project_batches, schema_to_pyarrow

target_schema = schema_to_pyarrow(self.projection())
return pa.RecordBatchReader.from_batches(
schema_to_pyarrow(self.projection()),
target_schema,
project_batches(
self.plan_files(),
self.table_metadata,
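
As a side note, a minimal sketch (illustrative names and data) of the
`RecordBatchReader.from_batches` pattern used above: an explicit target schema
plus a lazy iterator of batches, so callers can stream results without
materializing a full table.

```python
import pyarrow as pa

schema = pa.schema([pa.field("foo", pa.large_string())])

def batches():
    for chunk in (["a", "b"], ["c"]):
        yield pa.record_batch([pa.array(chunk, type=pa.large_string())], schema=schema)

reader = pa.RecordBatchReader.from_batches(schema, batches())
print(reader.read_all().num_rows)  # 3
```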
80 changes: 72 additions & 8 deletions tests/integration/test_add_files.py
@@ -18,7 +18,7 @@

import os
from datetime import date
from typing import Iterator, Optional
from typing import Iterator

import pyarrow as pa
import pyarrow.parquet as pq
@@ -28,7 +28,8 @@

from pyiceberg.catalog import Catalog
from pyiceberg.exceptions import NoSuchTableError
from pyiceberg.partitioning import PartitionField, PartitionSpec
from pyiceberg.io import FileIO
from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC, PartitionField, PartitionSpec
from pyiceberg.schema import Schema
from pyiceberg.table import Table
from pyiceberg.transforms import BucketTransform, IdentityTransform, MonthTransform
@@ -107,23 +108,32 @@
)


def _write_parquet(io: FileIO, file_path: str, arrow_schema: pa.Schema, arrow_table: pa.Table) -> None:
fo = io.new_output(file_path)
with fo.create(overwrite=True) as fos:
with pq.ParquetWriter(fos, schema=arrow_schema) as writer:
writer.write_table(arrow_table)


def _create_table(
session_catalog: Catalog, identifier: str, format_version: int, partition_spec: Optional[PartitionSpec] = None
session_catalog: Catalog,
identifier: str,
format_version: int,
partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC,
schema: Schema = TABLE_SCHEMA,
) -> Table:
try:
session_catalog.drop_table(identifier=identifier)
except NoSuchTableError:
pass

tbl = session_catalog.create_table(
return session_catalog.create_table(
identifier=identifier,
schema=TABLE_SCHEMA,
schema=schema,
properties={"format-version": str(format_version)},
partition_spec=partition_spec if partition_spec else PartitionSpec(),
partition_spec=partition_spec,
)

return tbl


@pytest.fixture(name="format_version", params=[pytest.param(1, id="format_version=1"), pytest.param(2, id="format_version=2")])
def format_version_fixure(request: pytest.FixtureRequest) -> Iterator[int]:
@@ -454,6 +464,60 @@ def test_add_files_snapshot_properties(spark: SparkSession, session_catalog: Cat


@pytest.mark.integration
def test_add_files_with_large_and_regular_schema(spark: SparkSession, session_catalog: Catalog, format_version: int) -> None:
identifier = f"default.unpartitioned_with_large_types{format_version}"

iceberg_schema = Schema(NestedField(1, "foo", StringType(), required=True))
arrow_schema = pa.schema([
pa.field("foo", pa.string(), nullable=False),
])
arrow_schema_large = pa.schema([
pa.field("foo", pa.large_string(), nullable=False),
])

tbl = _create_table(session_catalog, identifier, format_version, schema=iceberg_schema)

file_path = f"s3://warehouse/default/unpartitioned_with_large_types/v{format_version}/test-0.parquet"
_write_parquet(
tbl.io,
file_path,
arrow_schema,
pa.Table.from_pylist(
[
{
"foo": "normal",
}
],
schema=arrow_schema,
),
)

tbl.add_files([file_path])

table_schema = tbl.scan().to_arrow().schema
assert table_schema == arrow_schema_large

file_path_large = f"s3://warehouse/default/unpartitioned_with_large_types/v{format_version}/test-1.parquet"
_write_parquet(
tbl.io,
file_path_large,
arrow_schema_large,
pa.Table.from_pylist(
[
{
"foo": "normal",
}
],
schema=arrow_schema_large,
),
)

tbl.add_files([file_path_large])

table_schema = tbl.scan().to_arrow().schema
assert table_schema == arrow_schema_large


def test_timestamp_tz_ns_downcast_on_read(session_catalog: Catalog, format_version: int, mocker: MockerFixture) -> None:
nanoseconds_schema_iceberg = Schema(NestedField(1, "quux", TimestamptzType()))

2 changes: 1 addition & 1 deletion tests/integration/test_deletes.py
@@ -291,7 +291,7 @@ def test_partitioned_table_positional_deletes_sequence_number(spark: SparkSessio
assert snapshots[2].summary == Summary(
Operation.OVERWRITE,
**{
"added-files-size": "1145",
"added-files-size": snapshots[2].summary["total-files-size"],
"added-data-files": "1",
"added-records": "2",
"changed-partition-count": "1",
9 changes: 6 additions & 3 deletions tests/integration/test_inspect_table.py
@@ -110,22 +110,25 @@ def test_inspect_snapshots(
for manifest_list in df["manifest_list"]:
assert manifest_list.as_py().startswith("s3://")

file_size = int(next(value for key, value in df["summary"][0].as_py() if key == "added-files-size"))
assert file_size > 0

# Append
assert df["summary"][0].as_py() == [
("added-files-size", "5459"),
("added-files-size", str(file_size)),
("added-data-files", "1"),
("added-records", "3"),
("total-data-files", "1"),
("total-delete-files", "0"),
("total-records", "3"),
("total-files-size", "5459"),
("total-files-size", str(file_size)),
("total-position-deletes", "0"),
("total-equality-deletes", "0"),
]

# Delete
assert df["summary"][1].as_py() == [
("removed-files-size", "5459"),
("removed-files-size", str(file_size)),
("deleted-data-files", "1"),
("deleted-records", "3"),
("total-data-files", "0"),
12 changes: 8 additions & 4 deletions tests/integration/test_writes/test_partitioned_writes.py
@@ -252,28 +252,32 @@ def test_summaries_with_null(spark: SparkSession, session_catalog: Catalog, arro
assert operations == ["append", "append"]

summaries = [row.summary for row in rows]

file_size = int(summaries[0]["added-files-size"])
assert file_size > 0

assert summaries[0] == {
"changed-partition-count": "3",
"added-data-files": "3",
"added-files-size": "15029",
"added-files-size": str(file_size),
"added-records": "3",
"total-data-files": "3",
"total-delete-files": "0",
"total-equality-deletes": "0",
"total-files-size": "15029",
"total-files-size": str(file_size),
"total-position-deletes": "0",
"total-records": "3",
}

assert summaries[1] == {
"changed-partition-count": "3",
"added-data-files": "3",
"added-files-size": "15029",
"added-files-size": str(file_size),
"added-records": "3",
"total-data-files": "6",
"total-delete-files": "0",
"total-equality-deletes": "0",
"total-files-size": "30058",
"total-files-size": str(file_size * 2),
"total-position-deletes": "0",
"total-records": "6",
}