-
Notifications
You must be signed in to change notification settings - Fork 608
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(api): improve efficiency of the `__dataframe__` protocol
- Loading branch information
Showing
5 changed files
with
313 additions
and
15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
from __future__ import annotations

import pyarrow as pa
import pytest

# Compare the version numerically. The previous check
# (`pa.__version__ < "12."`) was a lexicographic *string* comparison, so
# e.g. "9.0.0" compared greater than "12." and the skip silently failed
# to fire on pyarrow < 10.
_PYARROW_VERSION = tuple(map(int, pa.__version__.split(".")[:2]))

pytestmark = pytest.mark.skipif(
    _PYARROW_VERSION < (12, 0), reason="pyarrow >= 12 required"
)
||
@pytest.mark.notimpl(["dask", "druid"])
@pytest.mark.notimpl(
    ["impala"], raises=AttributeError, reason="missing `fetchmany` on the cursor"
)
def test_dataframe_interchange_no_execute(con, alltypes, mocker):
    """Metadata-only interchange methods must not execute the expression."""
    expr = alltypes.select("int_col", "double_col", "string_col")
    expected = expr.to_pyarrow().__dataframe__()

    spy = mocker.spy(con, "to_pyarrow")

    ibis_df = expr.__dataframe__()

    # Schema-level metadata.
    assert ibis_df.num_columns() == expected.num_columns()
    assert ibis_df.column_names() == expected.column_names()

    # Column access and dtype lookups.
    assert ibis_df.get_column(0).dtype == expected.get_column(0).dtype
    assert (
        ibis_df.get_column_by_name("int_col").dtype
        == expected.get_column_by_name("int_col").dtype
    )
    assert [c.dtype for c in ibis_df.get_columns()] == [
        c.dtype for c in expected.get_columns()
    ]
    first = ibis_df.get_column(0)
    with pytest.raises(
        TypeError, match="only works on a column with categorical dtype"
    ):
        first.describe_categorical  # noqa: B018

    # Column subselection.
    sub = ibis_df.select_columns([1, 0])
    assert sub.column_names() == expected.select_columns([1, 0]).column_names()
    sub = ibis_df.select_columns_by_name(["double_col", "int_col"])
    expected_sub = expected.select_columns_by_name(["double_col", "int_col"])
    assert sub.column_names() == expected_sub.column_names()

    # Nested __dataframe__ access stays lazy as well.
    assert (
        ibis_df.__dataframe__().column_names()
        == expected.__dataframe__().column_names()
    )

    # None of the above should have triggered an execution.
    assert not spy.called
|
||
@pytest.mark.notimpl(["dask"])
@pytest.mark.notimpl(
    ["impala"], raises=AttributeError, reason="missing `fetchmany` on the cursor"
)
def test_dataframe_interchange_dataframe_methods_execute(con, alltypes, mocker):
    """DataFrame methods that need real data execute the expression once."""
    expr = alltypes.select("int_col", "double_col", "string_col")
    expected = expr.to_pyarrow().__dataframe__()

    spy = mocker.spy(con, "to_pyarrow")

    ibis_df = expr.__dataframe__()

    assert spy.call_count == 0
    assert ibis_df.metadata == expected.metadata
    assert ibis_df.num_rows() == expected.num_rows()
    assert ibis_df.num_chunks() == expected.num_chunks()
    assert len(list(ibis_df.get_chunks())) == ibis_df.num_chunks()
    # All of the data-requiring calls above share a single cached execution.
    assert spy.call_count == 1
|
||
@pytest.mark.notimpl(["dask", "druid"])
@pytest.mark.notimpl(
    ["impala"], raises=AttributeError, reason="missing `fetchmany` on the cursor"
)
def test_dataframe_interchange_column_methods_execute(con, alltypes, mocker):
    """Column methods that need real data execute the expression exactly once."""
    expr = alltypes.select("int_col", "double_col", "string_col")
    expected = expr.to_pyarrow().__dataframe__()

    spy = mocker.spy(con, "to_pyarrow")

    ibis_df = expr.__dataframe__()
    col = ibis_df.get_column(0)
    expected_col = expected.get_column(0)

    assert spy.call_count == 0
    assert col.size() == expected_col.size()
    assert col.offset == expected_col.offset

    assert col.describe_null == expected_col.describe_null
    assert col.null_count == expected_col.null_count
    assert col.metadata == expected_col.metadata
    assert col.num_chunks() == expected_col.num_chunks()
    assert len(list(col.get_chunks())) == expected_col.num_chunks()
    assert len(list(col.get_buffers())) == len(list(expected_col.get_buffers()))
    # All of the data-requiring calls above share a single cached execution.
    assert spy.call_count == 1

    # Accessing another column reuses the cached result rather than
    # executing again.
    other = ibis_df.get_column(1)
    expected_other = expected.get_column(1)
    assert other.size() == expected_other.size()
|
||
@pytest.mark.notimpl(["dask"])
@pytest.mark.notimpl(
    ["impala"], raises=AttributeError, reason="missing `fetchmany` on the cursor"
)
def test_dataframe_interchange_select_after_execution_no_reexecute(
    con, alltypes, mocker
):
    """Subselecting columns after execution reuses the cached pyarrow table."""
    expr = alltypes.select("int_col", "double_col", "string_col")
    expected = expr.to_pyarrow().__dataframe__()

    spy = mocker.spy(con, "to_pyarrow")

    ibis_df = expr.__dataframe__()

    # Force a data-loading operation.
    assert spy.call_count == 0
    assert ibis_df.num_rows() == expected.num_rows()
    assert spy.call_count == 1

    # A column subselection must not trigger a second execution.
    sub = ibis_df.select_columns([1, 0])
    expected_sub = expected.select_columns([1, 0])
    assert sub.num_rows() == expected_sub.num_rows()
    assert sub.column_names() == expected_sub.column_names()
    assert spy.call_count == 1
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,173 @@ | ||
from __future__ import annotations | ||
|
||
from functools import cached_property | ||
from typing import TYPE_CHECKING, Sequence | ||
|
||
if TYPE_CHECKING: | ||
import ibis.expr.types as ir | ||
import pyarrow as pa | ||
|
||
|
||
class IbisDataFrame:
    """Dataframe interchange protocol implementation backed by an ibis Table.

    A thin shim over pyarrow's implementation that:

    - Answers several metadata queries without executing the expression.
    - Caches the executed result on this object so accessing multiple
      methods does not re-run the query.

    Protocol spec: https://data-apis.org/dataframe-protocol/latest/API.html
    """

    def __init__(
        self,
        table: ir.Table,
        nan_as_null: bool = False,
        allow_copy: bool = True,
        pyarrow_table: pa.Table | None = None,
    ):
        self._table = table
        self._nan_as_null = nan_as_null
        self._allow_copy = allow_copy
        # Holds the already-executed result, if any, so derived frames can
        # reuse it instead of executing again.
        self._pyarrow_table = pyarrow_table

    @cached_property
    def _pyarrow_df(self):
        """pyarrow's ``__dataframe__`` over the executed table.

        First access executes the backing ibis expression (if it has not run
        already) and caches the resulting pyarrow table on this object.
        """
        if self._pyarrow_table is None:
            self._pyarrow_table = self._table.to_pyarrow()
        return self._pyarrow_table.__dataframe__(
            nan_as_null=self._nan_as_null,
            allow_copy=self._allow_copy,
        )

    @cached_property
    def _empty_pyarrow_df(self):
        """pyarrow ``__dataframe__`` for an empty table with this schema.

        Lets dtype queries be answered without executing the backing ibis
        expression.
        """
        return self._table.schema().to_pyarrow().empty_table().__dataframe__()

    def _get_dtype(self, name):
        """Return the interchange dtype info for the column named ``name``."""
        return self._empty_pyarrow_df.get_column_by_name(name).dtype

    # -- Everything below here is answered without executing the query -----

    def num_columns(self):
        return len(self._table.columns)

    def column_names(self):
        return self._table.columns

    def get_column(self, i: int) -> IbisColumn:
        return self.get_column_by_name(self._table.columns[i])

    def get_column_by_name(self, name: str) -> IbisColumn:
        return IbisColumn(self, name)

    def get_columns(self):
        return [IbisColumn(self, col) for col in self._table.columns]

    def select_columns(self, indices: Sequence[int]) -> IbisDataFrame:
        columns = self._table.columns
        return self.select_columns_by_name([columns[i] for i in indices])

    def select_columns_by_name(self, names: Sequence[str]) -> IbisDataFrame:
        names = list(names)
        table = self._table.select(names)
        pyarrow_table = self._pyarrow_table
        if pyarrow_table is not None:
            # Carry the cached execution forward (restricted to the kept
            # columns) so the derived frame doesn't re-run the query.
            pyarrow_table = pyarrow_table.select(names)
        return IbisDataFrame(
            table,
            nan_as_null=self._nan_as_null,
            allow_copy=self._allow_copy,
            pyarrow_table=pyarrow_table,
        )

    def __dataframe__(
        self, nan_as_null: bool = False, allow_copy: bool = True
    ) -> IbisDataFrame:
        return IbisDataFrame(
            self._table,
            nan_as_null=nan_as_null,
            allow_copy=allow_copy,
            pyarrow_table=self._pyarrow_table,
        )

    # -- Everything below here requires executing the query ----------------

    @property
    def metadata(self):
        return self._pyarrow_df.metadata

    def num_rows(self) -> int | None:
        return self._pyarrow_df.num_rows()

    def num_chunks(self) -> int:
        return self._pyarrow_df.num_chunks()

    def get_chunks(self, n_chunks: int | None = None):
        return self._pyarrow_df.get_chunks(n_chunks=n_chunks)
|
||
|
||
class IbisColumn:
    """Column piece of the dataframe interchange protocol.

    Defers to the parent ``IbisDataFrame`` (and through it, pyarrow) for
    everything that requires real data.
    """

    def __init__(self, df: IbisDataFrame, name: str):
        self._df = df
        self._name = name

    @cached_property
    def _pyarrow_col(self):
        """pyarrow's interchange Column for this column.

        First access executes (and caches) the backing ibis expression via
        the parent dataframe.
        """
        return self._df._pyarrow_df.get_column_by_name(self._name)

    # -- Answered without executing the query ------------------------------

    @property
    def dtype(self):
        # dtype info comes from an empty table with the same schema.
        return self._df._get_dtype(self._name)

    @property
    def describe_categorical(self):
        raise TypeError(
            "describe_categorical only works on a column with categorical dtype"
        )

    # -- Everything below here requires executing the query ----------------

    def size(self):
        return self._pyarrow_col.size()

    @property
    def offset(self):
        return self._pyarrow_col.offset

    @property
    def describe_null(self):
        return self._pyarrow_col.describe_null

    @property
    def null_count(self):
        return self._pyarrow_col.null_count

    @property
    def metadata(self):
        return self._pyarrow_col.metadata

    def num_chunks(self) -> int:
        return self._pyarrow_col.num_chunks()

    def get_chunks(self, n_chunks: int | None = None):
        return self._pyarrow_col.get_chunks(n_chunks=n_chunks)

    def get_buffers(self):
        return self._pyarrow_col.get_buffers()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
# Ruff/pydocstyle configuration.
# D102 = "missing docstring in public method".
select = ["D102"]
show-source = true

[per-file-ignores]
# The interchange-protocol shim's methods are defined by an external spec,
# so per-method docstrings are not required there.
"dataframe_interchange.py" = ["D102"]