diff --git a/ibis/backends/base/sql/compiler/select_builder.py b/ibis/backends/base/sql/compiler/select_builder.py
index 01a2df46d862..a84dddf27fa2 100644
--- a/ibis/backends/base/sql/compiler/select_builder.py
+++ b/ibis/backends/base/sql/compiler/select_builder.py
@@ -306,6 +306,11 @@ def _collect_PandasInMemoryTable(self, node, toplevel=False):
             self.select_set = [node]
             self.table_set = node
 
+    def _collect_PyArrowInMemoryTable(self, node, toplevel=False):
+        if toplevel:
+            self.select_set = [node]
+            self.table_set = node
+
     def _convert_group_by(self, nodes):
         return list(range(len(nodes)))
 
diff --git a/ibis/backends/pyarrow/__init__.py b/ibis/backends/pyarrow/__init__.py
index e69de29bb2d1..7ebdbfdd2f39 100644
--- a/ibis/backends/pyarrow/__init__.py
+++ b/ibis/backends/pyarrow/__init__.py
@@ -0,0 +1,46 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import pyarrow as pa
+
+import ibis.expr.operations as ops
+import ibis.expr.rules as rlz
+import ibis.expr.schema as sch
+from ibis import util
+from ibis.common.grounds import Immutable
+
+if TYPE_CHECKING:
+    import pandas as pd
+
+
+class PyArrowTableProxy(Immutable, util.ToFrame):
+    __slots__ = ('_t', '_hash')
+
+    def __init__(self, t: pa.Table) -> None:
+        object.__setattr__(self, "_t", t)
+        object.__setattr__(self, "_hash", hash((type(t), id(t))))
+
+    def __hash__(self) -> int:
+        return self._hash
+
+    def __repr__(self) -> str:
+        df_repr = util.indent(repr(self._t), spaces=2)
+        return f"{self.__class__.__name__}:\n{df_repr}"
+
+    def to_frame(self) -> pd.DataFrame:
+        return self._t.to_pandas()
+
+    def to_pyarrow(self, _: sch.Schema) -> pa.Table:
+        return self._t
+
+
+class PyArrowInMemoryTable(ops.InMemoryTable):
+    data = rlz.instance_of(PyArrowTableProxy)
+
+
+@sch.infer.register(pa.Table)
+def infer_pyarrow_table_schema(t: pa.Table, schema=None):
+    import ibis.backends.pyarrow.datatypes  # noqa: F401
+
+    return sch.schema(schema if schema is not None else t.schema)
diff --git a/ibis/backends/pyspark/compiler.py b/ibis/backends/pyspark/compiler.py
index 0f7fab9208ee..e3576c9829ab 100644
--- a/ibis/backends/pyspark/compiler.py
+++ b/ibis/backends/pyspark/compiler.py
@@ -20,6 +20,7 @@
 from ibis.backends.base.df.timecontext import adjust_context
 from ibis.backends.pandas.client import PandasInMemoryTable
 from ibis.backends.pandas.execution import execute
+from ibis.backends.pyarrow import PyArrowInMemoryTable
 from ibis.backends.pyspark.datatypes import spark_dtype
 from ibis.backends.pyspark.timecontext import (
     combine_time_context,
@@ -1862,8 +1863,8 @@ def compile_random(*args, **kwargs):
     return F.rand()
 
 
-@compiles(ops.InMemoryTable)
 @compiles(PandasInMemoryTable)
+@compiles(PyArrowInMemoryTable)
 def compile_in_memory_table(t, op, session, **kwargs):
     fields = [
         pt.StructField(name, spark_dtype(dtype), dtype.nullable)
diff --git a/ibis/backends/tests/test_generic.py b/ibis/backends/tests/test_generic.py
index 3048f5e3ffd3..85c3430138b8 100644
--- a/ibis/backends/tests/test_generic.py
+++ b/ibis/backends/tests/test_generic.py
@@ -920,6 +920,29 @@ def test_memtable_bool_column(backend, con, monkeypatch):
     backend.assert_series_equal(t.a.execute(), pd.Series([True, False, True], name="a"))
 
 
+@pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError)
+@pytest.mark.notimpl(["dask", "pandas"], raises=com.UnboundExpressionError)
+@pytest.mark.broken(
+    ["druid"],
+    raises=AssertionError,
+    reason="result contains empty strings instead of None",
+)
+def test_memtable_construct(backend, con, monkeypatch):
+    pa = pytest.importorskip("pyarrow")
+    monkeypatch.setattr(ibis.options, "default_backend", con)
+
+    pa_t = pa.Table.from_pydict(
+        {
+            "a": list("abc"),
+            "b": [1, 2, 3],
+            "c": [1.0, 2.0, 3.0],
+            "d": [None, "b", None],
+        }
+    )
+    t = ibis.memtable(pa_t)
+    backend.assert_frame_equal(t.execute(), pa_t.to_pandas())
+
+
 @pytest.mark.notimpl(
     ["dask", "datafusion", "pandas", "polars"],
     raises=NotImplementedError,
diff --git a/ibis/expr/api.py b/ibis/expr/api.py
index ef4c55e92a93..82002f3c0833 100644
--- a/ibis/expr/api.py
+++ b/ibis/expr/api.py
@@ -42,6 +42,7 @@
 
 if TYPE_CHECKING:
     import pandas as pd
+    import pyarrow as pa
 
     from ibis.common.typing import SupportsSchema
 
@@ -324,10 +325,10 @@ def memtable(
     Parameters
     ----------
     data
-        Any data accepted by the `pandas.DataFrame` constructor.
+        Any data accepted by the `pandas.DataFrame` constructor or a `pyarrow.Table`.
 
-        The use of `DataFrame` underneath should **not** be relied upon and is
-        free to change across non-major releases.
+        Do not depend on the underlying storage type (e.g., `pyarrow.Table`); it is
+        subject to change across non-major releases.
     columns
         Optional [`Iterable`][typing.Iterable] of [`str`][str] column names.
     schema
@@ -393,6 +394,15 @@ def memtable(
             "passing `columns` and `schema` is ambiguous; "
             "pass one or the other but not both"
         )
+
+    try:
+        import pyarrow as pa
+    except ImportError:
+        pass
+    else:
+        if isinstance(data, pa.Table):
+            return _memtable_from_pyarrow_table(data, name=name, schema=schema)
+
     df = pd.DataFrame(data, columns=columns)
     if df.columns.inferred_type != "string":
         cols = df.columns
@@ -421,6 +431,18 @@ def _memtable_from_dataframe(
     return op.to_expr()
 
 
+def _memtable_from_pyarrow_table(
+    data: pa.Table, *, name: str | None = None, schema: SupportsSchema | None = None
+):
+    from ibis.backends.pyarrow import PyArrowInMemoryTable, PyArrowTableProxy
+
+    return PyArrowInMemoryTable(
+        name=name if name is not None else util.generate_unique_table_name("memtable"),
+        schema=sch.infer(data) if schema is None else schema,
+        data=PyArrowTableProxy(data),
+    ).to_expr()
+
+
 def _deferred_method_call(expr, method_name):
     method = operator.methodcaller(method_name)
     if isinstance(expr, str):
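
Usage sketch (editorial note, not part of the diff): with this change applied,
`ibis.memtable` accepts a `pyarrow.Table` directly instead of coercing the input
through `pandas.DataFrame`. A minimal example, assuming a default backend has
been configured; the DuckDB connection below is an illustrative choice, not
something this diff prescribes:

    import pyarrow as pa

    import ibis

    # Any connected backend works here; the diff's own test sets
    # ibis.options.default_backend via monkeypatch instead.
    ibis.options.default_backend = ibis.duckdb.connect()

    # memtable() now detects pa.Table via isinstance(data, pa.Table) and wraps
    # it in PyArrowInMemoryTable; the schema comes from pa_t.schema through
    # infer_pyarrow_table_schema rather than pandas dtype inference.
    pa_t = pa.Table.from_pydict({"a": list("abc"), "b": [1, 2, 3]})
    t = ibis.memtable(pa_t)

    # Executes on the configured default backend, returning a pandas DataFrame.
    print(t.execute())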