Skip to content

Commit

Permalink
feat(api): implement isin
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Feb 6, 2023
1 parent e16b91f commit ac31db2
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 44 deletions.
33 changes: 8 additions & 25 deletions ibis/backends/clickhouse/compiler/values.py
Original file line number Diff line number Diff line change
Expand Up @@ -810,8 +810,8 @@ def _string_contains(op, **kw):


def contains(op_string: Literal["IN", "NOT IN"]) -> str:
def translate(op, *, cache, **kw):
import ibis.expr.analysis as an
def tr(op, *, cache, **kw):
from ibis.backends.clickhouse.compiler import translate

value = op.value
options = op.options
Expand All @@ -827,36 +827,19 @@ def translate(op, *, cache, **kw):
not isinstance(options, tuple)
and options.output_shape is rlz.Shape.COLUMNAR
):
leaves = list(an.find_immediate_parent_tables(options))
nleaves = len(leaves)
if nleaves > 1:
raise NotImplementedError(
"more than one leaf table in a NOT IN/IN query unsupported"
)
(leaf,) = leaves

shared_roots_count = sum(
an.shares_all_roots(value, child)
for child in an.find_immediate_parent_tables(options)
)
if shared_roots_count == nleaves:
from ibis.backends.clickhouse.compiler.relations import translate_rel

op = options.to_expr().as_table().op()
subquery = translate_rel(op, table=cache[leaf], **kw)
right_arg = f"({subquery})"
else:
raise NotImplementedError(
"ClickHouse doesn't support correlated subqueries"
)
# this will fail to execute if there's a correlation, but it's too
# annoying to detect so we let it through to enable the
# uncorrelated use case (pandas-style `.isin`)
subquery = translate(options.to_expr().as_table().op(), {})
right_arg = f"({subquery})"
else:
right_arg = translate_val(options, cache=cache, **kw)

# we explicitly do NOT parenthesize the right side because it doesn't
# make sense to do so for Sequence operations
return f"{left_arg} {op_string} {right_arg}"

return translate
return tr


translate_val.register(ops.Contains)(contains("IN"))
Expand Down
Original file line number Diff line number Diff line change
@@ -1,32 +1,29 @@
WITH t0 AS (
SELECT t3.`userid`, t3.`movieid`, t3.`rating`,
CAST(t3.`timestamp` AS timestamp) AS `datetime`
FROM ratings t3
SELECT t2.`userid`, t2.`movieid`, t2.`rating`,
CAST(t2.`timestamp` AS timestamp) AS `datetime`
FROM ratings t2
),
t1 AS (
SELECT t0.*, t4.`title`
SELECT t0.*, t3.`title`
FROM t0
INNER JOIN movies t4
ON t0.`movieid` = t4.`movieid`
INNER JOIN movies t3
ON t0.`movieid` = t3.`movieid`
)
SELECT t2.*
FROM (
SELECT t1.*
FROM t1
WHERE (t1.`userid` = 118205) AND
(extract(t1.`datetime`, 'year') > 2001)
) t2
WHERE t2.`movieid` IN (
SELECT t3.`movieid`
SELECT t1.*
FROM t1
WHERE (t1.`userid` = 118205) AND
(extract(t1.`datetime`, 'year') > 2001) AND
(t1.`movieid` IN (
SELECT t2.`movieid`
FROM (
SELECT t4.`movieid`
SELECT t3.`movieid`
FROM (
SELECT t1.*
FROM t1
WHERE (t1.`userid` = 118205) AND
(extract(t1.`datetime`, 'year') > 2001) AND
(t1.`userid` = 118205) AND
(extract(t1.`datetime`, 'year') < 2009)
) t4
) t3
)
) t3
) t2
))
43 changes: 43 additions & 0 deletions ibis/backends/tests/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -839,3 +839,46 @@ def test_typeof(backend, con):
result = con.execute(expr)

assert result is not None


@pytest.mark.broken(["polars"], reason="incorrect answer")
@pytest.mark.notimpl(["datafusion", "bigquery", "impala", "pyspark"])
@pytest.mark.notyet(["dask", "mssql"], reason="not supported by the backend")
def test_isin_uncorrelated(
backend, batting, awards_players, batting_df, awards_players_df
):
expr = batting.select(
"playerID",
"yearID",
x=batting.yearID.isin(awards_players.yearID),
).order_by(["playerID", "yearID"])
result = expr.execute().x
expected = (
batting_df.sort_values(["playerID", "yearID"])
.reset_index(drop=True)
.yearID.isin(awards_players_df.yearID)
.rename("x")
)
backend.assert_series_equal(result, expected)


@pytest.mark.broken(["polars"], reason="incorrect answer")
@pytest.mark.notimpl(["datafusion", "pyspark"])
@pytest.mark.notyet(["dask"], reason="not supported by the backend")
def test_isin_uncorrelated_filter(
backend, batting, awards_players, batting_df, awards_players_df
):
expr = (
batting.select("playerID", "yearID")
.filter(batting.yearID.isin(awards_players.yearID))
.order_by(["playerID", "yearID"])
)
result = expr.execute()
expected = (
batting_df.loc[
batting_df.yearID.isin(awards_players_df.yearID), ["playerID", "yearID"]
]
.sort_values(["playerID", "yearID"])
.reset_index(drop=True)
)
backend.assert_frame_equal(result, expected)
8 changes: 8 additions & 0 deletions ibis/expr/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,12 @@ def finder(node):
return g.halt, node
else:
return g.proceed, None

# HACK: special case ops.Contains to only consider the needle's base
# table, since that's the only expression that matters for determining
# cardinality
elif isinstance(node, ops.Contains):
return [node.value], None
else:
return g.proceed, None

Expand Down Expand Up @@ -646,6 +652,8 @@ def _find_projections(node):
return g.proceed, None
elif isinstance(node, ops.TableNode):
return g.halt, node
elif isinstance(node, ops.Contains):
return [node.value], None
else:
return g.proceed, None

Expand Down

0 comments on commit ac31db2

Please sign in to comment.