From 4951842f72d34cf578570ed7298d9c15cf75a590 Mon Sep 17 00:00:00 2001
From: Jay Chia <17691182+jaychia@users.noreply.github.com>
Date: Thu, 20 Jun 2024 17:09:09 -0700
Subject: [PATCH] [FEAT] Add ability to specify snapshot_id for iceberg read
 (#2426)

Closes #2400

---------

Co-authored-by: Jay Chia
---
 daft/iceberg/iceberg_scan.py                  | 15 +++++++++----
 daft/io/_iceberg.py                           |  4 +++-
 .../iceberg/docker-compose/provision.py       | 21 +++++++++++++++++++
 tests/integration/iceberg/test_table_load.py  | 12 +++++++++++
 4 files changed, 47 insertions(+), 5 deletions(-)

diff --git a/daft/iceberg/iceberg_scan.py b/daft/iceberg/iceberg_scan.py
index b5f74bfe8b..73adb840c3 100644
--- a/daft/iceberg/iceberg_scan.py
+++ b/daft/iceberg/iceberg_scan.py
@@ -82,15 +82,22 @@ def iceberg_partition_spec_to_fields(iceberg_schema: IcebergSchema, spec: Iceber
 
 
 class IcebergScanOperator(ScanOperator):
-    def __init__(self, iceberg_table: Table, storage_config: StorageConfig) -> None:
+    def __init__(self, iceberg_table: Table, snapshot_id: int | None, storage_config: StorageConfig) -> None:
         super().__init__()
         self._table = iceberg_table
+        self._snapshot_id = snapshot_id
         self._storage_config = storage_config
-        iceberg_schema = iceberg_table.schema()
+
+        iceberg_schema = (
+            iceberg_table.schema()
+            if self._snapshot_id is None
+            else self._table.scan(snapshot_id=self._snapshot_id).projection()
+        )
         arrow_schema = schema_to_pyarrow(iceberg_schema)
         self._field_id_mapping = visit(iceberg_schema, SchemaFieldIdMappingVisitor())
         self._schema = Schema.from_pyarrow_schema(arrow_schema)
-        self._partition_keys = iceberg_partition_spec_to_fields(self._table.schema(), self._table.spec())
+
+        self._partition_keys = iceberg_partition_spec_to_fields(iceberg_schema, self._table.spec())
 
     def schema(self) -> Schema:
         return self._schema
@@ -129,7 +136,7 @@ def multiline_display(self) -> list[str]:
 
     def to_scan_tasks(self, pushdowns: Pushdowns) -> Iterator[ScanTask]:
         limit = pushdowns.limit
-        iceberg_tasks = self._table.scan(limit=limit).plan_files()
+        iceberg_tasks = self._table.scan(limit=limit, snapshot_id=self._snapshot_id).plan_files()
 
         limit_files = limit is not None and pushdowns.filters is None and pushdowns.partition_filters is None
 
diff --git a/daft/io/_iceberg.py b/daft/io/_iceberg.py
index 17dd3fe488..1ee46da41e 100644
--- a/daft/io/_iceberg.py
+++ b/daft/io/_iceberg.py
@@ -85,6 +85,7 @@ def _convert_iceberg_file_io_properties_to_io_config(props: Dict[str, Any]) -> O
 @PublicAPI
 def read_iceberg(
     pyiceberg_table: "PyIcebergTable",
+    snapshot_id: Optional[int] = None,
     io_config: Optional["IOConfig"] = None,
 ) -> DataFrame:
     """Create a DataFrame from an Iceberg table
@@ -106,6 +107,7 @@ def read_iceberg(
 
     Args:
         pyiceberg_table: Iceberg table created using the PyIceberg library
+        snapshot_id: Snapshot ID of the table to query
         io_config: A custom IOConfig to use when accessing Iceberg object storage data. Defaults to None.
 
     Returns:
@@ -123,7 +125,7 @@ def read_iceberg(
     multithreaded_io = not context.get_context().is_ray_runner
     storage_config = StorageConfig.native(NativeStorageConfig(multithreaded_io, io_config))
 
-    iceberg_operator = IcebergScanOperator(pyiceberg_table, storage_config=storage_config)
+    iceberg_operator = IcebergScanOperator(pyiceberg_table, snapshot_id=snapshot_id, storage_config=storage_config)
 
     handle = ScanOperatorHandle.from_python_scan_operator(iceberg_operator)
     builder = LogicalPlanBuilder.from_tabular_scan(scan_operator=handle)
diff --git a/tests/integration/iceberg/docker-compose/provision.py b/tests/integration/iceberg/docker-compose/provision.py
index 692f714f5e..5d2c96cac9 100644
--- a/tests/integration/iceberg/docker-compose/provision.py
+++ b/tests/integration/iceberg/docker-compose/provision.py
@@ -406,3 +406,24 @@
 )
 
 spark.sql("INSERT INTO default.test_evolve_partitioning VALUES (CAST('2021-02-01' AS date))")
+
+
+###
+# Multi-snapshot table
+###
+
+spark.sql("""
+    CREATE OR REPLACE TABLE default.test_snapshotting
+    USING iceberg
+    AS SELECT
+        1 AS idx,
+        float('NaN') AS col_numeric
+UNION ALL SELECT
+        2 AS idx,
+        null AS col_numeric
+UNION ALL SELECT
+        3 AS idx,
+        1 AS col_numeric
+""")
+
+spark.sql("INSERT INTO default.test_snapshotting VALUES (4, 1)")
diff --git a/tests/integration/iceberg/test_table_load.py b/tests/integration/iceberg/test_table_load.py
index 94b813b8f7..1fdc67b792 100644
--- a/tests/integration/iceberg/test_table_load.py
+++ b/tests/integration/iceberg/test_table_load.py
@@ -151,3 +151,15 @@ def test_daft_iceberg_table_read_partition_column_transformed(local_iceberg_cata
     iceberg_pandas = tab.scan().to_arrow().to_pandas()
     iceberg_pandas = iceberg_pandas[["number"]]
     assert_df_equals(daft_pandas, iceberg_pandas, sort_key=[])
+
+
+@pytest.mark.integration()
+def test_daft_iceberg_table_read_table_snapshot(local_iceberg_catalog):
+    tab = local_iceberg_catalog.load_table("default.test_snapshotting")
+    snapshots = tab.history()
+    assert len(snapshots) == 2
+
+    for snapshot in snapshots:
+        daft_pandas = daft.read_iceberg(tab, snapshot_id=snapshot.snapshot_id).to_pandas()
+        iceberg_pandas = tab.scan(snapshot_id=snapshot.snapshot_id).to_pandas()
+        assert_df_equals(daft_pandas, iceberg_pandas, sort_key=[])
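
---

Example usage: a minimal sketch of time-travel reads with the new snapshot_id
parameter, assuming a running PyIceberg catalog. The catalog and table names
below are hypothetical placeholders, not part of this patch.

    import daft
    from pyiceberg.catalog import load_catalog

    # Load an Iceberg table via PyIceberg (catalog/table names are hypothetical).
    catalog = load_catalog("default")
    table = catalog.load_table("default.test_snapshotting")

    # Pick an older snapshot from the table's history; each entry carries a snapshot_id.
    snapshot_id = table.history()[0].snapshot_id

    # Read the table as of that snapshot; omitting snapshot_id reads the latest state.
    df = daft.read_iceberg(table, snapshot_id=snapshot_id)
    print(df.to_pandas())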