Skip to content

Commit

Permalink
fix(ir): only dereference comparisons not generic binary operations
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud authored and kszucs committed Feb 12, 2024
1 parent 5fa088e commit 05ac73a
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 3 deletions.
42 changes: 42 additions & 0 deletions ibis/expr/tests/test_newrels.py
Original file line number Diff line number Diff line change
Expand Up @@ -1314,3 +1314,45 @@ def test_join_method_docstrings():
join_method = getattr(joined, method)
table_method = getattr(t1, method)
assert join_method.__doc__ == table_method.__doc__


def test_join_with_compound_predicate():
t1 = ibis.table(name="t", schema={"a": "string", "b": "string"})
t2 = t1.view()

joined = t1.join(
t2,
[
t1.a == t2.a,
(t1.a != t2.b) | (t1.b != t2.a),
(t1.a != t2.b) ^ (t1.b != t2.a),
(t1.a != t2.b) & (t1.b != t2.a),
(t1.a + t1.a != t2.b) & (t1.b + t1.b != t2.a),
],
)
expr = joined[t1]
with join_tables(t1, t2) as (r1, r2):
expected = ops.JoinChain(
first=r1,
rest=[
ops.JoinLink(
"inner",
r2,
[
r1.a == r2.a,
(r1.a != r2.b) | (r1.b != r2.a),
(r1.a != r2.b) ^ (r1.b != r2.a),
# these are flattened
r1.a != r2.b,
r1.b != r2.a,
r1.a + r1.a != r2.b,
r1.b + r1.b != r2.a,
],
),
],
values={
"a": r1.a,
"b": r1.b,
},
)
assert expr.op() == expected
6 changes: 3 additions & 3 deletions ibis/expr/types/joins.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,15 +81,15 @@ def dereference_sides(left, right, deref_left, deref_right):
return left, right


def dereference_binop(pred, deref_left, deref_right):
def dereference_comparison_op(pred, deref_left, deref_right):
left, right = dereference_sides(pred.left, pred.right, deref_left, deref_right)
return pred.copy(left=left, right=right)


def dereference_value(pred, deref_left, deref_right):
deref_both = {**deref_left, **deref_right}
if isinstance(pred, ops.Binary) and pred.left.relations == pred.right.relations:
return dereference_binop(pred, deref_left, deref_right)
if isinstance(pred, ops.Comparison) and pred.left.relations == pred.right.relations:
return dereference_comparison_op(pred, deref_left, deref_right)
else:
return pred.replace(deref_both, filter=ops.Value)

Expand Down

0 comments on commit 05ac73a

Please sign in to comment.