diff --git a/ibis/backends/datafusion/compiler.py b/ibis/backends/datafusion/compiler.py index 807c8e4b3d06..e65cc23dee95 100644 --- a/ibis/backends/datafusion/compiler.py +++ b/ibis/backends/datafusion/compiler.py @@ -47,6 +47,7 @@ def in_memory_table(op, ctx, **kw): schema = op.schema if data := op.data: + ctx.deregister_table(op.name) return ctx.from_arrow_table(data.to_pyarrow(schema), name=op.name) # datafusion panics when given an empty table @@ -92,8 +93,7 @@ def cast(op, **kw): @translate.register(ops.TableColumn) def column(op, **_): - id_parts = [getattr(op.table, "name", None), op.name] - return df.column(".".join(f'"{id}"' for id in id_parts if id)) + return df.column(f'"{op.name}"') @translate.register(ops.SortKey) @@ -829,3 +829,59 @@ def extract_query(op, **kw): if op.key is not None else extract_query_udf(arg) ) + + +_join_types = { + ops.InnerJoin: "inner", + ops.LeftJoin: "left", + ops.RightJoin: "right", + ops.OuterJoin: "full", + ops.LeftAntiJoin: "anti", + ops.LeftSemiJoin: "semi", +} + + +@translate.register(ops.Join) +def join(op, **kw): + left = translate(op.left, **kw) + right = translate(op.right, **kw) + + right_table = op.right + if isinstance(op, ops.RightJoin): + how = "left" + right_table = op.left + left, right = right, left + else: + how = _join_types[type(op)] + + left_cols = set(left.schema().names) + right_cols = {} + for col in right.schema().names: + if col in left_cols: + right_cols[col] = f"{col}_right" + else: + right_cols[col] = f"{col}" + + left_keys, right_keys = [], [] + for pred in op.predicates: + if isinstance(pred, ops.Equals): + left_keys.append(f'"{pred.left.name}"') + right_keys.append(f'"{right_cols[pred.right.name]}"') + else: + raise com.TranslationError( + "DataFusion backend is unable to compile join predicate " + f"with operation type of {type(pred)}" + ) + + right = translate( + ops.Selection( + right_table, + [ + ops.Alias(ops.TableColumn(right_table, key), value) + for key, value in right_cols.items() + ], + ), + **kw, + ) + + return left.join(right, join_keys=(left_keys, right_keys), how=how) diff --git a/ibis/backends/tests/test_aggregation.py b/ibis/backends/tests/test_aggregation.py index 92794fff5b49..e6f2022555ec 100644 --- a/ibis/backends/tests/test_aggregation.py +++ b/ibis/backends/tests/test_aggregation.py @@ -1269,7 +1269,6 @@ def test_topk_op(alltypes, df): ) ], ) -@mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError) @mark.broken( ["bigquery"], raises=GoogleBadRequest, diff --git a/ibis/backends/tests/test_join.py b/ibis/backends/tests/test_join.py index 5ae3f8dfa9ac..7f3b7b5d7dbb 100644 --- a/ibis/backends/tests/test_join.py +++ b/ibis/backends/tests/test_join.py @@ -59,13 +59,13 @@ def check_eq(left, right, how, **kwargs): # syntax, but we might be able to work around that using # LEFT JOIN UNION RIGHT JOIN marks=pytest.mark.notimpl( - ["mysql"] + ["mysql", "datafusion"] + (["sqlite"] * (vparse(sqlite3.sqlite_version) < vparse("3.39"))) ), ), ], ) -@pytest.mark.notimpl(["datafusion", "druid"]) +@pytest.mark.notimpl(["druid"]) @pytest.mark.xfail_version( polars=["polars>=0.18.6,<0.18.8"], reason="https://github.com/pola-rs/polars/issues/9955", @@ -119,7 +119,7 @@ def test_mutating_join(backend, batting, awards_players, how): @pytest.mark.parametrize("how", ["semi", "anti"]) -@pytest.mark.notimpl(["bigquery", "dask", "datafusion", "druid"]) +@pytest.mark.notimpl(["bigquery", "dask", "druid"]) def test_filtering_join(backend, batting, awards_players, how): left = batting[batting.yearID == 2015] right = awards_players[awards_players.lgID == "NL"].drop("yearID", "lgID") @@ -148,7 +148,6 @@ def test_filtering_join(backend, batting, awards_players, how): backend.assert_frame_equal(result, expected, check_like=True) -@pytest.mark.notimpl(["datafusion"]) @pytest.mark.broken( ["polars"], raises=ValueError, @@ -166,7 +165,6 @@ def test_join_then_filter_no_column_overlap(awards_players, batting): assert not q.execute().empty -@pytest.mark.notimpl(["datafusion"]) @pytest.mark.broken( ["polars"], raises=ValueError, @@ -180,7 +178,7 @@ def test_mutate_then_join_no_column_overlap(batting, awards_players): assert not expr.limit(5).execute().empty -@pytest.mark.notimpl(["datafusion", "bigquery", "druid"]) +@pytest.mark.notimpl(["bigquery", "druid"]) @pytest.mark.notyet(["dask"], reason="dask doesn't support descending order by") @pytest.mark.broken( ["polars"], @@ -205,7 +203,7 @@ def test_semi_join_topk(batting, awards_players, func): assert not expr.limit(5).execute().empty -@pytest.mark.notimpl(["dask", "datafusion", "druid"]) +@pytest.mark.notimpl(["dask", "druid"]) def test_join_with_pandas(batting, awards_players): batting_filt = batting[lambda t: t.yearID < 1900] awards_players_filt = awards_players[lambda t: t.yearID < 1900].execute() @@ -215,14 +213,14 @@ def test_join_with_pandas(batting, awards_players): assert df.yearID.nunique() == 7 -@pytest.mark.notimpl(["dask", "datafusion"]) +@pytest.mark.notimpl(["dask"]) def test_join_with_pandas_non_null_typed_columns(batting, awards_players): batting_filt = batting[lambda t: t.yearID < 1900][["yearID"]] awards_players_filt = awards_players[lambda t: t.yearID < 1900][ ["yearID"] ].execute() - # ensure that none of the columns of eitherr table have type null + # ensure that none of the columns of either table have type null batting_schema = batting_filt.schema() assert len(batting_schema) == 1 assert batting_schema["yearID"].is_integer() @@ -297,10 +295,7 @@ def test_join_with_pandas_non_null_typed_columns(batting, awards_players): ], ) @pytest.mark.notimpl( - ["datafusion"], raises=com.OperationNotDefinedError, reason="joins not implemented" -) -@pytest.mark.notimpl( - ["polars"], + ["polars", "datafusion"], raises=com.TranslationError, reason="polars doesn't support join predicates", )