Skip to content

Commit

Permalink
feat(api): support boolean literals in join API
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Jul 31, 2023
1 parent b457c7b commit c56376f
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 2 deletions.
102 changes: 102 additions & 0 deletions ibis/backends/tests/test_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,24 @@
import numpy as np
import pandas as pd
import pytest
import sqlalchemy as sa
from packaging.version import parse as vparse
from pytest import param

import ibis
import ibis.common.exceptions as com
import ibis.expr.schema as sch

try:
from polars.exceptions import ColumnNotFoundError
except ImportError:
ColumnNotFoundError = None

try:
from impala.error import HiveServer2Error
except ImportError:
HiveServer2Error = None


def _pandas_semi_join(left, right, on, **_):
assert len(on) == 1, str(on)
Expand Down Expand Up @@ -224,3 +232,97 @@ def test_join_with_pandas_non_null_typed_columns(batting, awards_players):
expr = batting_filt.join(awards_players_filt, "yearID")
df = expr.execute()
assert df.yearID.nunique() == 7


@pytest.mark.parametrize(
("predicate", "pandas_value"),
[
# Trues
param(True, True, id="true"),
param(ibis.literal(True), True, id="true-literal"),
param([True], True, id="true-list"),
param([ibis.literal(True)], True, id="true-literal-list"),
# only trues
param([True, True], True, id="true-true-list"),
param(
[ibis.literal(True), ibis.literal(True)], True, id="true-true-literal-list"
),
param([True, ibis.literal(True)], True, id="true-true-const-expr-list"),
param([ibis.literal(True), True], True, id="true-true-expr-const-list"),
# Falses
param(False, False, id="false"),
param(ibis.literal(False), False, id="false-literal"),
param([False], False, id="false-list"),
param([ibis.literal(False)], False, id="false-literal-list"),
# only falses
param([False, False], False, id="false-false-list"),
param(
[ibis.literal(False), ibis.literal(False)],
False,
id="false-false-literal-list",
),
param([False, ibis.literal(False)], False, id="false-false-const-expr-list"),
param([ibis.literal(False), False], False, id="false-false-expr-const-list"),
],
)
@pytest.mark.parametrize(
"how",
[
"inner",
"left",
"right",
param(
"outer",
marks=[
pytest.mark.notyet(
["mysql"],
raises=sa.exc.ProgrammingError,
reason="MySQL doesn't support full outer joins natively",
),
pytest.mark.notyet(
["impala"],
raises=HiveServer2Error,
reason=(
"impala doesn't support full outer joins with non-equi-join "
"predicates"
),
),
pytest.mark.notyet(
["sqlite"],
condition=vparse(sqlite3.sqlite_version) < vparse("3.39"),
reason="sqlite didn't support full outer join until 3.39",
),
],
),
],
)
@pytest.mark.notimpl(
["datafusion"], raises=com.OperationNotDefinedError, reason="joins not implemented"
)
@pytest.mark.notimpl(
["polars"],
raises=com.TranslationError,
reason="polars doesn't support join predicates",
)
@pytest.mark.notimpl(
["dask", "pandas"],
raises=TypeError,
reason="dask and pandas don't support join predicates",
)
def test_join_with_trivial_predicate(awards_players, predicate, how, pandas_value):
n = 5

base = awards_players.limit(n)

left = base.select(left_key="playerID")
right = base.select(right_key="playerID")

left_df = pd.DataFrame({"key": [True] * n})
right_df = pd.DataFrame({"key": [pandas_value] * n})

expected = pd.merge(left_df, right_df, on="key", how=how)

expr = left.join(right, predicate, how=how)
result = expr.to_pandas()

assert len(result) == len(expected)
6 changes: 4 additions & 2 deletions ibis/expr/operations/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,8 @@ def _clean_join_predicates(left, right, predicates):
pred = lk == rk
elif isinstance(pred, str):
pred = left.to_expr()[pred] == right.to_expr()[pred]
elif pred is True or pred is False:
pred = ops.Literal(pred, dtype="bool").to_expr()
elif isinstance(pred, Value):
pred = pred.to_expr()
elif isinstance(pred, Deferred):
Expand All @@ -177,8 +179,8 @@ def _clean_join_predicates(left, right, predicates):
elif not isinstance(pred, ir.Expr):
raise NotImplementedError

if not isinstance(pred, ir.BooleanColumn):
raise com.ExpressionError('Join predicate must be comparison')
if not isinstance(pred, ir.BooleanValue):
raise com.ExpressionError('Join predicate must be a boolean expression')

preds = an.flatten_predicate(pred.op())
result.extend(preds)
Expand Down

0 comments on commit c56376f

Please sign in to comment.