From a9d6fdf0101f7069d1f5be6bf6e20bea1f58f5f3 Mon Sep 17 00:00:00 2001
From: "R. Tyler Croy"
Date: Mon, 30 Dec 2024 21:34:07 +0000
Subject: [PATCH] fix: introduce a reproduction case for List casting with polars

See #3063

Signed-off-by: R. Tyler Croy
---
NOTE(review): the line layout of this patch was reconstructed from a
flattened copy; the indentation of context lines is assumed from each file
format's conventions -- confirm against the target files before applying.

 .github/workflows/python_build.yml |  2 +-
 python/Makefile                    |  8 +++---
 python/pyproject.toml              |  2 ++
 python/tests/test_writer.py        | 40 ++++++++++++++++++++++++++++++
 4 files changed, 47 insertions(+), 5 deletions(-)

diff --git a/.github/workflows/python_build.yml b/.github/workflows/python_build.yml
index 2098485853..a76cc1f472 100644
--- a/.github/workflows/python_build.yml
+++ b/.github/workflows/python_build.yml
@@ -46,7 +46,7 @@ jobs:
       - name: Build and install deltalake
         run: |
           # Install minimum PyArrow version
-          uv sync --extra devel --extra pandas
+          uv sync --extra devel --extra pandas --extra polars
           uv pip install pyarrow==16.0.0
         env:
           RUSTFLAGS: "-C debuginfo=line-tables-only"
diff --git a/python/Makefile b/python/Makefile
index f342f333a7..4bf299e452 100644
--- a/python/Makefile
+++ b/python/Makefile
@@ -8,7 +8,7 @@ DAT_VERSION := 0.0.2
 .PHONY: setup
 setup: ## Setup the requirements
 	$(info --- Setup dependencies ---)
-	uv sync --extra devel --extra pandas
+	uv sync --extra devel --extra pandas --extra polars
 
 .PHONY: setup-dat
 setup-dat: ## Download DAT test files
@@ -28,7 +28,7 @@ build: setup ## Build Python binding of delta-rs
 .PHONY: develop
 develop: setup ## Install Python binding of delta-rs
 	$(info --- Develop with Python binding ---)
-	uvx --from 'maturin[zig]' maturin develop --extras=devel,pandas $(MATURIN_EXTRA_ARGS)
+	uvx --from 'maturin[zig]' maturin develop --extras=devel,pandas,polars $(MATURIN_EXTRA_ARGS)
 
 .PHONY: install
 install: build ## Install Python binding of delta-rs
@@ -36,13 +36,13 @@ install: build ## Install Python binding of delta-rs
 	uv pip uninstall deltalake
 	$(info --- Install Python binding ---)
 	$(eval TARGET_WHEEL := $(shell ls ../target/wheels/deltalake-${PACKAGE_VERSION}-*.whl))
-	uv pip install $(TARGET_WHEEL)[devel,pandas]
+	uv pip install $(TARGET_WHEEL)[devel,pandas,polars]
 
 .PHONY: develop-pyspark
 develop-pyspark:
 	uv sync --all-extras
 	$(info --- Develop with Python binding ---)
-	uvx --from 'maturin[zig]' maturin develop --extras=devel,pandas,pyspark $(MATURIN_EXTRA_ARGS)
+	uvx --from 'maturin[zig]' maturin develop --extras=devel,pandas,polars,pyspark $(MATURIN_EXTRA_ARGS)
 
 .PHONY: format
 format: ## Format the code
diff --git a/python/pyproject.toml b/python/pyproject.toml
index b937109930..8968d0c609 100644
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -38,6 +38,7 @@ devel = [
     "mypy==1.10.1",
     "ruff==0.5.2",
 ]
+polars = ["polars==1.17.1"]
 pyspark = [
     "pyspark",
     "delta-spark",
@@ -93,6 +94,7 @@ markers = [
     "s3: marks tests as integration tests with S3 (deselect with '-m \"not s3\"')",
     "azure: marks tests as integration tests with Azure Blob Store",
     "pandas: marks tests that require pandas",
+    "polars: marks tests that require polars",
     "pyspark: marks tests that require pyspark",
 ]
 
diff --git a/python/tests/test_writer.py b/python/tests/test_writer.py
index a6662c48d6..e58cb545ea 100644
--- a/python/tests/test_writer.py
+++ b/python/tests/test_writer.py
@@ -2025,3 +2025,43 @@ def test_write_transactions(tmp_path: pathlib.Path, sample_data: pa.Table):
     assert transaction_2.app_id == "app_2"
     assert transaction_2.version == 2
     assert transaction_2.last_updated == 123456
+
+
+# Reproduction case for #3063: merge a polars frame containing a struct column
+@pytest.mark.polars
+def test_write_structs(tmp_path: pathlib.Path):
+    import polars as pl
+
+    dt = DeltaTable.create(
+        tmp_path,
+        schema=pa.schema(
+            [
+                ("a", pa.int32()),
+                ("b", pa.string()),
+                ("c", pa.struct({"d": pa.int16(), "e": pa.int16()})),
+            ]
+        ),
+    )
+
+    df = pl.DataFrame(
+        {
+            "a": [0, 1],
+            "b": ["x", "y"],
+            "c": [
+                {"d": -55, "e": -32},
+                {"d": 0, "e": 0},
+            ],
+        }
+    )
+
+    dt.merge(
+        source=df.to_arrow(),
+        predicate=" AND ".join([f"target.{x} = source.{x}" for x in ["a"]]),
+        source_alias="source",
+        target_alias="target",
+        large_dtypes=False,
+    ).when_not_matched_insert_all().execute()
+
+    arrow_dt = dt.to_pyarrow_dataset()
+    new_df = pl.scan_pyarrow_dataset(arrow_dt)
+    new_df.collect()