From a55f79d7593a4f0bf07c81cb32c6e8a22a878805 Mon Sep 17 00:00:00 2001 From: J van Zundert Date: Thu, 1 Jun 2023 06:48:52 +0100 Subject: [PATCH] fix(python): Fix DataFrame.to_arrow() for 0x0 dataframes (#9144) --- py-polars/polars/dataframe/frame.py | 7 +++++-- py-polars/tests/unit/test_interop.py | 17 +++++++++++++++++ 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py index b8699c08d3c6..267d51460ac1 100644 --- a/py-polars/polars/dataframe/frame.py +++ b/py-polars/polars/dataframe/frame.py @@ -1845,8 +1845,11 @@ def to_arrow(self) -> pa.Table: bar: [["a","b","c","d","e","f"]] """ - record_batches = self._df.to_arrow() - return pa.Table.from_batches(record_batches) + if self.shape[1]: # all except 0x0 dataframe + record_batches = self._df.to_arrow() + return pa.Table.from_batches(record_batches) + else: # 0x0 dataframe, cannot infer schema from batches + return pa.table({}) @overload def to_dict(self, as_series: Literal[True] = ...) -> dict[str, Series]: diff --git a/py-polars/tests/unit/test_interop.py b/py-polars/tests/unit/test_interop.py index d1ff18876f85..059ad1d4b5ef 100644 --- a/py-polars/tests/unit/test_interop.py +++ b/py-polars/tests/unit/test_interop.py @@ -220,6 +220,23 @@ def test_arrow_null_roundtrip() -> None: assert c1.to_pylist() == c2.to_pylist() +def test_arrow_empty_dataframe() -> None: + # 0x0 dataframe + df = pl.DataFrame({}) + tbl = pa.table({}) + assert df.to_arrow() == tbl + df2 = cast(pl.DataFrame, pl.from_arrow(df.to_arrow())) + assert_frame_equal(df2, df) + + # 0 row dataframe + df = pl.DataFrame({}, schema={"a": pl.Int32}) + tbl = pa.Table.from_batches([], pa.schema([pa.field("a", pa.int32())])) + assert df.to_arrow() == tbl + df2 = cast(pl.DataFrame, pl.from_arrow(df.to_arrow())) + assert df2.schema == {"a": pl.Int32} + assert df2.shape == (0, 1) + + def test_arrow_dict_to_polars() -> None: pa_dict = pa.DictionaryArray.from_arrays( indices=np.array([0, 1, 2, 3, 1, 0, 2, 3, 3, 2]),