Skip to content

Commit

Permalink
feat(datafusion): implement in-memory table
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Jul 25, 2023
1 parent 05b02da commit d4ec5c2
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 21 deletions.
22 changes: 22 additions & 0 deletions ibis/backends/datafusion/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,28 @@ def dummy_table(op, ctx, **kw):
)


@translate.register(ops.InMemoryTable)
def in_memory_table(op, ctx, **kw):
schema = op.schema

if data := op.data:
return ctx.from_arrow_table(data.to_pyarrow(schema), name=op.name)

# datafusion panics when given an empty table
return (
ctx.empty_table()
.select(
*(
translate(
ops.Alias(ops.Literal(None, dtype=dtype), name), ctx=ctx, **kw
)
for name, dtype in schema.items()
)
)
.limit(0)
)


@translate.register(ops.Alias)
def alias(op, **kw):
arg = translate(op.arg, **kw)
Expand Down
3 changes: 0 additions & 3 deletions ibis/backends/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -778,13 +778,11 @@ def test_invalid_connect(tmp_path):
),
],
)
@pytest.mark.notimpl(["datafusion"])
def test_in_memory_table(backend, con, expr, expected):
result = con.execute(expr)
backend.assert_frame_equal(result, expected)


@pytest.mark.notimpl(["datafusion"])
def test_filter_memory_table(backend, con):
t = ibis.memtable([(1, 2), (3, 4), (5, 6)], columns=["x", "y"])
expr = t.filter(t.x > 1)
Expand All @@ -793,7 +791,6 @@ def test_filter_memory_table(backend, con):
backend.assert_frame_equal(result, expected)


@pytest.mark.notimpl(["datafusion"])
def test_agg_memory_table(con):
t = ibis.memtable([(1, 2), (3, 4), (5, 6)], columns=["x", "y"])
expr = t.x.count()
Expand Down
14 changes: 2 additions & 12 deletions ibis/backends/tests/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import ibis
import ibis.expr.datatypes as dt
from ibis import util
from ibis.common.exceptions import OperationNotDefinedError

pa = pytest.importorskip("pyarrow")

Expand Down Expand Up @@ -194,15 +193,15 @@ def test_to_pyarrow_batches_borked_types(batting):
util.consume(batch_reader)


@pytest.mark.notimpl(["dask", "datafusion", "impala", "pyspark"])
@pytest.mark.notimpl(["dask", "impala", "pyspark"])
def test_to_pyarrow_memtable(con):
expr = ibis.memtable({"x": [1, 2, 3]})
table = con.to_pyarrow(expr)
assert isinstance(table, pa.Table)
assert len(table) == 3


@pytest.mark.notimpl(["dask", "datafusion", "impala", "pyspark"])
@pytest.mark.notimpl(["dask", "impala", "pyspark"])
def test_to_pyarrow_batches_memtable(con):
expr = ibis.memtable({"x": [1, 2, 3]})
n = 0
Expand Down Expand Up @@ -269,10 +268,6 @@ def test_roundtrip_partitioned_parquet(tmp_path, con, backend, awards_players):
@pytest.mark.notimpl(
["dask", "impala", "pyspark"], reason="No support for exporting files"
)
@pytest.mark.notimpl(
["datafusion"],
reason="No memtable support",
)
@pytest.mark.parametrize("ftype", ["csv", "parquet"])
def test_memtable_to_file(tmp_path, con, ftype, monkeypatch):
"""
Expand Down Expand Up @@ -456,11 +451,6 @@ def test_to_torch(alltypes):
non_numeric.to_torch()


@pytest.mark.notimpl(
["datafusion"],
raises=OperationNotDefinedError,
reason="InMemoryTable not yet implemented for the datafusion backend",
)
def test_empty_memtable(backend, con):
expected = pd.DataFrame({"a": []})
table = ibis.memtable(expected)
Expand Down
10 changes: 4 additions & 6 deletions ibis/backends/tests/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -917,15 +917,13 @@ def test_literal_na(con, dtype):
assert pd.isna(result)


@pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError)
def test_memtable_bool_column(backend, con, monkeypatch):
monkeypatch.setattr(ibis.options, "default_backend", con)

def test_memtable_bool_column(backend, con):
t = ibis.memtable({"a": [True, False, True]})
backend.assert_series_equal(t.a.execute(), pd.Series([True, False, True], name="a"))
backend.assert_series_equal(
con.execute(t.a), pd.Series([True, False, True], name="a")
)


@pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError)
@pytest.mark.broken(
["druid"],
raises=AssertionError,
Expand Down

0 comments on commit d4ec5c2

Please sign in to comment.