Skip to content

Commit

Permalink
fix(pyspark): enable joining on columns with different names as well as complex predicates
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud authored and kszucs committed May 11, 2023
1 parent 20f3011 commit dcee821
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 22 deletions.
14 changes: 5 additions & 9 deletions ibis/backends/pyspark/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import enum
import functools
import operator
from functools import partial, reduce

import pyspark
import pyspark.sql.functions as F
Expand Down Expand Up @@ -396,7 +397,7 @@ def compile_aggregation(t, op, **kwargs):
)

if op.predicates:
predicate = functools.reduce(ops.And, op.predicates)
predicate = reduce(ops.And, op.predicates)
src_table = src_table.filter(t.translate(predicate, **kwargs))

if op.by:
Expand Down Expand Up @@ -1121,14 +1122,9 @@ def compile_join(t, op, how, **kwargs):
left_df = t.translate(op.left, **kwargs)
right_df = t.translate(op.right, **kwargs)

pred_columns = []
for pred in op.predicates:
if not isinstance(pred, ops.Equals):
raise NotImplementedError(
f"Only equality predicate is supported, but got {type(pred)}"
)
pred_columns.append(pred.left.name)

pred_columns = reduce(
operator.and_, map(partial(t.translate, **kwargs), op.predicates)
)
return left_df.join(right_df, pred_columns, how)


Expand Down
14 changes: 1 addition & 13 deletions ibis/backends/tests/test_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,11 +156,7 @@ def test_filtering_join(backend, batting, awards_players, how):
backend.assert_frame_equal(result, expected, check_like=True)


@pytest.mark.notyet(
["pyspark"],
reason="pyspark doesn't support joining on differing column names",
)
@pytest.mark.notimpl(["datafusion", "pyspark"])
@pytest.mark.notimpl(["datafusion"])
def test_join_then_filter_no_column_overlap(awards_players, batting):
left = batting[batting.yearID == 2015]
year = left.yearID.name("year")
Expand All @@ -174,10 +170,6 @@ def test_join_then_filter_no_column_overlap(awards_players, batting):


@pytest.mark.notimpl(["datafusion"])
@pytest.mark.notyet(
["pyspark"],
reason="pyspark doesn't support joining on differing column names",
)
def test_mutate_then_join_no_column_overlap(batting, awards_players):
left = batting.mutate(year=batting.yearID).filter(lambda t: t.year == 2015)
left = left["year", "RBI"]
Expand All @@ -187,10 +179,6 @@ def test_mutate_then_join_no_column_overlap(batting, awards_players):


@pytest.mark.notimpl(["datafusion", "bigquery", "druid"])
@pytest.mark.notyet(
["pyspark"],
reason="pyspark doesn't support joining on differing column names",
)
@pytest.mark.notyet(["dask"], reason="dask doesn't support descending order by")
def test_semi_join_topk(batting, awards_players):
batting = batting.mutate(year=batting.yearID)
Expand Down

0 comments on commit dcee821

Please sign in to comment.