Skip to content

Commit

Permalink
feat(backends): allow column expressions from non-foreign tables on t…
Browse files Browse the repository at this point in the history
…he right side of `isin`/`notin`
  • Loading branch information
cpcloud committed May 17, 2022
1 parent 29ee19a commit e1374a4
Show file tree
Hide file tree
Showing 10 changed files with 197 additions and 52 deletions.
33 changes: 23 additions & 10 deletions ibis/backends/base/sql/alchemy/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,22 +170,35 @@ def _cast(t, expr):
sa_arg = t.translate(arg)
sa_type = t.get_sqla_type(target_type)

if isinstance(arg, ir.CategoryValue) and target_type == 'int32':
if isinstance(arg, ir.CategoryValue) and target_type == dt.int32:
return sa_arg
else:
return sa.cast(sa_arg, sa_type)


def _contains(t, expr):
op = expr.op()

left, right = (t.translate(arg) for arg in op.args)
def _contains(func):
def translate(t, expr):
op = expr.op()

return left.in_(right)
raw_left, raw_right = op.args
left = t.translate(raw_left)
right = t.translate(raw_right)

if (
# not a list expr
not isinstance(raw_right, ir.ListExpr)
# but still a column expr
and isinstance(raw_right, ir.ColumnExpr)
# wasn't already compiled into a select statement
and not isinstance(right, sa.sql.Selectable)
):
right = sa.select(right)
else:
right = t.translate(raw_right)

return func(left, right)

def _not_contains(t, expr):
return sa.not_(_contains(t, expr))
return translate


def reduction(sa_func):
Expand Down Expand Up @@ -462,8 +475,8 @@ def _string_join(t, expr):
ops.Cast: _cast,
ops.Coalesce: varargs(sa.func.coalesce),
ops.NullIf: fixed_arity(sa.func.nullif, 2),
ops.Contains: _contains,
ops.NotContains: _not_contains,
ops.Contains: _contains(lambda left, right: left.in_(right)),
ops.NotContains: _contains(lambda left, right: left.notin_(right)),
ops.Count: reduction(sa.func.count),
ops.Sum: reduction(sa.func.sum),
ops.Mean: reduction(sa.func.avg),
Expand Down
67 changes: 40 additions & 27 deletions ibis/backends/base/sql/registry/binary_infix.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Literal

import ibis.expr.types as ir
from ibis.backends.base.sql.registry import helpers

Expand Down Expand Up @@ -53,34 +55,45 @@ def xor(translator, expr):
return '({0} OR {1}) AND NOT ({0} AND {1})'.format(left_arg, right_arg)


def isin(translator, expr):
op = expr.op()

left, right = op.args
if isinstance(right, ir.ListExpr) and not right:
return "FALSE"

left_arg = translator.translate(left)
right_arg = translator.translate(right)
if helpers.needs_parens(left):
left_arg = helpers.parenthesize(left_arg)

# we explicitly do NOT parenthesize the right side because it doesn't make
# sense to do so for ValueList operations

return f"{left_arg} IN {right_arg}"

def contains(op_string: Literal["IN", "NOT IN"]) -> str:
def translate(translator, expr):
from ibis.backends.base.sql.registry.main import table_array_view

def notin(translator, expr):
op = expr.op()
op = expr.op()

left, right = op.args
if isinstance(right, ir.ListExpr) and not right:
return "TRUE"
left, right = op.args
if isinstance(right, ir.ListExpr) and not right:
return {"NOT IN": "TRUE", "IN": "FALSE"}[op_string]

left_arg = translator.translate(left)
right_arg = translator.translate(right)
if helpers.needs_parens(left):
left_arg = helpers.parenthesize(left_arg)
left_arg = translator.translate(left)
if helpers.needs_parens(left):
left_arg = helpers.parenthesize(left_arg)

return f"{left_arg} NOT IN {right_arg}"
ctx = translator.context

# special case non-foreign isin/notin expressions
if (
not isinstance(right, ir.ListExpr)
and isinstance(right, ir.ColumnExpr)
# foreign refs are already been compiled correctly during
# TableColumn compilation
and not any(
ctx.is_foreign_expr(leaf.to_expr())
for leaf in right.op().root_tables()
)
):
if not right.has_name():
right = right.name("tmp")
right_arg = table_array_view(
translator,
right.to_projection().to_array(),
)
else:
right_arg = translator.translate(right)

# we explicitly do NOT parenthesize the right side because it doesn't
# make sense to do so for ValueList operations

return f"{left_arg} {op_string} {right_arg}"

return translate
4 changes: 2 additions & 2 deletions ibis/backends/base/sql/registry/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,8 +345,8 @@ def hash(translator, expr):
ops.Least: varargs('least'),
ops.Where: fixed_arity('if', 3),
ops.Between: between,
ops.Contains: binary_infix.isin,
ops.NotContains: binary_infix.notin,
ops.Contains: binary_infix.contains("IN"),
ops.NotContains: binary_infix.contains("NOT IN"),
ops.SimpleCase: case.simple_case,
ops.SearchedCase: case.searched_case,
ops.TableColumn: table_column,
Expand Down
4 changes: 2 additions & 2 deletions ibis/backends/clickhouse/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -715,8 +715,8 @@ def _string_right(translator, expr):
ops.DateAdd: binary_infix.binary_infix_op('+'),
ops.DateSub: binary_infix.binary_infix_op('-'),
ops.DateDiff: binary_infix.binary_infix_op('-'),
ops.Contains: binary_infix.isin,
ops.NotContains: binary_infix.notin,
ops.Contains: binary_infix.contains("IN"),
ops.NotContains: binary_infix.contains("NOT IN"),
ops.TimestampAdd: binary_infix.binary_infix_op('+'),
ops.TimestampSub: binary_infix.binary_infix_op('-'),
ops.TimestampDiff: binary_infix.binary_infix_op('-'),
Expand Down
10 changes: 8 additions & 2 deletions ibis/backends/dask/execution/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,13 +125,19 @@
],
ops.Contains: [
(
(dd.Series, (collections.abc.Sequence, collections.abc.Set)),
(
dd.Series,
(collections.abc.Sequence, collections.abc.Set, dd.Series),
),
execute_node_contains_series_sequence,
)
],
ops.NotContains: [
(
(dd.Series, (collections.abc.Sequence, collections.abc.Set)),
(
dd.Series,
(collections.abc.Sequence, collections.abc.Set, dd.Series),
),
execute_node_not_contains_series_sequence,
)
],
Expand Down
8 changes: 6 additions & 2 deletions ibis/backends/pandas/execution/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -870,14 +870,18 @@ def execute_node_string_join(op, args, **kwargs):


@execute_node.register(
ops.Contains, pd.Series, (collections.abc.Sequence, collections.abc.Set)
ops.Contains,
pd.Series,
(collections.abc.Sequence, collections.abc.Set, pd.Series),
)
def execute_node_contains_series_sequence(op, data, elements, **kwargs):
return data.isin(elements)


@execute_node.register(
ops.NotContains, pd.Series, (collections.abc.Sequence, collections.abc.Set)
ops.NotContains,
pd.Series,
(collections.abc.Sequence, collections.abc.Set, pd.Series),
)
def execute_node_not_contains_series_sequence(op, data, elements, **kwargs):
return ~(data.isin(elements))
Expand Down
47 changes: 44 additions & 3 deletions ibis/backends/tests/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

import numpy as np
import pandas as pd
import pandas.testing as tm
import pytest
from packaging.version import parse as vparse
from pytest import param
Expand Down Expand Up @@ -453,8 +452,50 @@ def test_table_info(alltypes):
),
],
)
def test_isin_notin(alltypes, df, ibis_op, pandas_op):
def test_isin_notin(backend, alltypes, df, ibis_op, pandas_op):
expr = alltypes[ibis_op]
expected = df.loc[pandas_op(df)].sort_values(["id"]).reset_index(drop=True)
result = expr.execute().sort_values(["id"]).reset_index(drop=True)
tm.assert_frame_equal(result, expected, check_index_type=False)
backend.assert_frame_equal(result, expected)


@pytest.mark.notyet(
["dask"],
reason="dask doesn't support Series as isin/notin argument",
raises=NotImplementedError,
)
@pytest.mark.notimpl(["datafusion"])
@pytest.mark.parametrize(
("ibis_op", "pandas_op"),
[
param(
_.string_col.isin(_.string_col),
lambda df: df.string_col.isin(df.string_col),
id="isin_col",
),
param(
(_.bigint_col + 1).isin(_.string_col.cast("int64") + 1),
lambda df: (df.bigint_col + 1).isin(
df.string_col.astype("int64") + 1
),
id="isin_expr",
),
param(
_.string_col.notin(_.string_col),
lambda df: ~df.string_col.isin(df.string_col),
id="notin_col",
),
param(
(_.bigint_col + 1).notin(_.string_col.cast("int64") + 1),
lambda df: ~(df.bigint_col + 1).isin(
df.string_col.astype("int64") + 1
),
id="notin_expr",
),
],
)
def test_isin_notin_column_expr(backend, alltypes, df, ibis_op, pandas_op):
expr = alltypes[ibis_op].sort_by("id")
expected = df[pandas_op(df)].sort_values(["id"]).reset_index(drop=True)
result = expr.execute()
backend.assert_frame_equal(result, expected)
65 changes: 64 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ pydocstyle = ">=6.1.1,<7"
pymdown-extensions = ">=9.1,<9.4"
pytest = ">=7.0.0,<8"
pytest-benchmark = ">=3.4.1,<4"
pytest-clarity = ">=1.0.1,<2"
pytest-cov = ">=3.0.0,<4"
pytest-mock = ">=3.6.1,<4"
pytest-profiling = ">=1.7.0,<2"
Expand Down
Loading

0 comments on commit e1374a4

Please sign in to comment.