Skip to content

Commit

Permalink
feat(datafusion): add join support
Browse files Browse the repository at this point in the history
  • Loading branch information
mesejo authored and cpcloud committed Aug 10, 2023
1 parent 914a9a3 commit e2c143a
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 16 deletions.
60 changes: 58 additions & 2 deletions ibis/backends/datafusion/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def in_memory_table(op, ctx, **kw):
schema = op.schema

if data := op.data:
ctx.deregister_table(op.name)
return ctx.from_arrow_table(data.to_pyarrow(schema), name=op.name)

# datafusion panics when given an empty table
Expand Down Expand Up @@ -92,8 +93,7 @@ def cast(op, **kw):

@translate.register(ops.TableColumn)
def column(op, **_):
id_parts = [getattr(op.table, "name", None), op.name]
return df.column(".".join(f'"{id}"' for id in id_parts if id))
return df.column(f'"{op.name}"')


@translate.register(ops.SortKey)
Expand Down Expand Up @@ -829,3 +829,59 @@ def extract_query(op, **kw):
if op.key is not None
else extract_query_udf(arg)
)


_join_types = {
ops.InnerJoin: "inner",
ops.LeftJoin: "left",
ops.RightJoin: "right",
ops.OuterJoin: "full",
ops.LeftAntiJoin: "anti",
ops.LeftSemiJoin: "semi",
}


@translate.register(ops.Join)
def join(op, **kw):
left = translate(op.left, **kw)
right = translate(op.right, **kw)

right_table = op.right
if isinstance(op, ops.RightJoin):
how = "left"
right_table = op.left
left, right = right, left
else:
how = _join_types[type(op)]

left_cols = set(left.schema().names)
right_cols = {}
for col in right.schema().names:
if col in left_cols:
right_cols[col] = f"{col}_right"
else:
right_cols[col] = f"{col}"

left_keys, right_keys = [], []
for pred in op.predicates:
if isinstance(pred, ops.Equals):
left_keys.append(f'"{pred.left.name}"')
right_keys.append(f'"{right_cols[pred.right.name]}"')
else:
raise com.TranslationError(
"DataFusion backend is unable to compile join predicate "
f"with operation type of {type(pred)}"
)

right = translate(
ops.Selection(
right_table,
[
ops.Alias(ops.TableColumn(right_table, key), value)
for key, value in right_cols.items()
],
),
**kw,
)

return left.join(right, join_keys=(left_keys, right_keys), how=how)
1 change: 0 additions & 1 deletion ibis/backends/tests/test_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1269,7 +1269,6 @@ def test_topk_op(alltypes, df):
)
],
)
@mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError)
@mark.broken(
["bigquery"],
raises=GoogleBadRequest,
Expand Down
21 changes: 8 additions & 13 deletions ibis/backends/tests/test_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,13 @@ def check_eq(left, right, how, **kwargs):
# syntax, but we might be able to work around that using
# LEFT JOIN UNION RIGHT JOIN
marks=pytest.mark.notimpl(
["mysql"]
["mysql", "datafusion"]
+ (["sqlite"] * (vparse(sqlite3.sqlite_version) < vparse("3.39")))
),
),
],
)
@pytest.mark.notimpl(["datafusion", "druid"])
@pytest.mark.notimpl(["druid"])
@pytest.mark.xfail_version(
polars=["polars>=0.18.6,<0.18.8"],
reason="https://github.com/pola-rs/polars/issues/9955",
Expand Down Expand Up @@ -119,7 +119,7 @@ def test_mutating_join(backend, batting, awards_players, how):


@pytest.mark.parametrize("how", ["semi", "anti"])
@pytest.mark.notimpl(["bigquery", "dask", "datafusion", "druid"])
@pytest.mark.notimpl(["bigquery", "dask", "druid"])
def test_filtering_join(backend, batting, awards_players, how):
left = batting[batting.yearID == 2015]
right = awards_players[awards_players.lgID == "NL"].drop("yearID", "lgID")
Expand Down Expand Up @@ -148,7 +148,6 @@ def test_filtering_join(backend, batting, awards_players, how):
backend.assert_frame_equal(result, expected, check_like=True)


@pytest.mark.notimpl(["datafusion"])
@pytest.mark.broken(
["polars"],
raises=ValueError,
Expand All @@ -166,7 +165,6 @@ def test_join_then_filter_no_column_overlap(awards_players, batting):
assert not q.execute().empty


@pytest.mark.notimpl(["datafusion"])
@pytest.mark.broken(
["polars"],
raises=ValueError,
Expand All @@ -180,7 +178,7 @@ def test_mutate_then_join_no_column_overlap(batting, awards_players):
assert not expr.limit(5).execute().empty


@pytest.mark.notimpl(["datafusion", "bigquery", "druid"])
@pytest.mark.notimpl(["bigquery", "druid"])
@pytest.mark.notyet(["dask"], reason="dask doesn't support descending order by")
@pytest.mark.broken(
["polars"],
Expand All @@ -205,7 +203,7 @@ def test_semi_join_topk(batting, awards_players, func):
assert not expr.limit(5).execute().empty


@pytest.mark.notimpl(["dask", "datafusion", "druid"])
@pytest.mark.notimpl(["dask", "druid"])
def test_join_with_pandas(batting, awards_players):
batting_filt = batting[lambda t: t.yearID < 1900]
awards_players_filt = awards_players[lambda t: t.yearID < 1900].execute()
Expand All @@ -215,14 +213,14 @@ def test_join_with_pandas(batting, awards_players):
assert df.yearID.nunique() == 7


@pytest.mark.notimpl(["dask", "datafusion"])
@pytest.mark.notimpl(["dask"])
def test_join_with_pandas_non_null_typed_columns(batting, awards_players):
batting_filt = batting[lambda t: t.yearID < 1900][["yearID"]]
awards_players_filt = awards_players[lambda t: t.yearID < 1900][
["yearID"]
].execute()

# ensure that none of the columns of eitherr table have type null
# ensure that none of the columns of either table have type null
batting_schema = batting_filt.schema()
assert len(batting_schema) == 1
assert batting_schema["yearID"].is_integer()
Expand Down Expand Up @@ -297,10 +295,7 @@ def test_join_with_pandas_non_null_typed_columns(batting, awards_players):
],
)
@pytest.mark.notimpl(
["datafusion"], raises=com.OperationNotDefinedError, reason="joins not implemented"
)
@pytest.mark.notimpl(
["polars"],
["polars", "datafusion"],
raises=com.TranslationError,
reason="polars doesn't support join predicates",
)
Expand Down

0 comments on commit e2c143a

Please sign in to comment.