diff --git a/ibis/backends/datafusion/compiler.py b/ibis/backends/datafusion/compiler.py index e32612433210..bc70c7658c46 100644 --- a/ibis/backends/datafusion/compiler.py +++ b/ibis/backends/datafusion/compiler.py @@ -41,6 +41,28 @@ def dummy_table(op, ctx, **kw): ) +@translate.register(ops.InMemoryTable) +def in_memory_table(op, ctx, **kw): + schema = op.schema + + if data := op.data: + return ctx.from_arrow_table(data.to_pyarrow(schema), name=op.name) + + # datafusion panics when given an empty table + return ( + ctx.empty_table() + .select( + *( + translate( + ops.Alias(ops.Literal(None, dtype=dtype), name), ctx=ctx, **kw + ) + for name, dtype in schema.items() + ) + ) + .limit(0) + ) + + @translate.register(ops.Alias) def alias(op, **kw): arg = translate(op.arg, **kw) diff --git a/ibis/backends/tests/test_client.py b/ibis/backends/tests/test_client.py index 239693534162..cedaee1ad23f 100644 --- a/ibis/backends/tests/test_client.py +++ b/ibis/backends/tests/test_client.py @@ -778,13 +778,11 @@ def test_invalid_connect(tmp_path): ), ], ) -@pytest.mark.notimpl(["datafusion"]) def test_in_memory_table(backend, con, expr, expected): result = con.execute(expr) backend.assert_frame_equal(result, expected) -@pytest.mark.notimpl(["datafusion"]) def test_filter_memory_table(backend, con): t = ibis.memtable([(1, 2), (3, 4), (5, 6)], columns=["x", "y"]) expr = t.filter(t.x > 1) @@ -793,7 +791,6 @@ def test_filter_memory_table(backend, con): backend.assert_frame_equal(result, expected) -@pytest.mark.notimpl(["datafusion"]) def test_agg_memory_table(con): t = ibis.memtable([(1, 2), (3, 4), (5, 6)], columns=["x", "y"]) expr = t.x.count() diff --git a/ibis/backends/tests/test_export.py b/ibis/backends/tests/test_export.py index 1fe6580b1197..fd071c4018a3 100644 --- a/ibis/backends/tests/test_export.py +++ b/ibis/backends/tests/test_export.py @@ -10,7 +10,6 @@ import ibis import ibis.expr.datatypes as dt from ibis import util -from ibis.common.exceptions import OperationNotDefinedError pa = pytest.importorskip("pyarrow") @@ -194,7 +193,7 @@ def test_to_pyarrow_batches_borked_types(batting): util.consume(batch_reader) -@pytest.mark.notimpl(["dask", "datafusion", "impala", "pyspark"]) +@pytest.mark.notimpl(["dask", "impala", "pyspark"]) def test_to_pyarrow_memtable(con): expr = ibis.memtable({"x": [1, 2, 3]}) table = con.to_pyarrow(expr) @@ -202,7 +201,7 @@ def test_to_pyarrow_memtable(con): assert len(table) == 3 -@pytest.mark.notimpl(["dask", "datafusion", "impala", "pyspark"]) +@pytest.mark.notimpl(["dask", "impala", "pyspark"]) def test_to_pyarrow_batches_memtable(con): expr = ibis.memtable({"x": [1, 2, 3]}) n = 0 @@ -269,10 +268,6 @@ def test_roundtrip_partitioned_parquet(tmp_path, con, backend, awards_players): @pytest.mark.notimpl( ["dask", "impala", "pyspark"], reason="No support for exporting files" ) -@pytest.mark.notimpl( - ["datafusion"], - reason="No memtable support", -) @pytest.mark.parametrize("ftype", ["csv", "parquet"]) def test_memtable_to_file(tmp_path, con, ftype, monkeypatch): """ @@ -456,11 +451,6 @@ def test_to_torch(alltypes): non_numeric.to_torch() -@pytest.mark.notimpl( - ["datafusion"], - raises=OperationNotDefinedError, - reason="InMemoryTable not yet implemented for the datafusion backend", -) def test_empty_memtable(backend, con): expected = pd.DataFrame({"a": []}) table = ibis.memtable(expected) diff --git a/ibis/backends/tests/test_generic.py b/ibis/backends/tests/test_generic.py index fd979aa09f31..85c67350d9dc 100644 --- a/ibis/backends/tests/test_generic.py +++ b/ibis/backends/tests/test_generic.py @@ -917,15 +917,13 @@ def test_literal_na(con, dtype): assert pd.isna(result) -@pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError) -def test_memtable_bool_column(backend, con, monkeypatch): - monkeypatch.setattr(ibis.options, "default_backend", con) - +def test_memtable_bool_column(backend, con): t = ibis.memtable({"a": [True, False, True]}) - backend.assert_series_equal(t.a.execute(), pd.Series([True, False, True], name="a")) + backend.assert_series_equal( + con.execute(t.a), pd.Series([True, False, True], name="a") + ) -@pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError) @pytest.mark.broken( ["druid"], raises=AssertionError,