Skip to content

Commit

Permalink
refactor(ir): remove ops.Negatable, ops.NotAny, ops.NotAll, ops.UnresolvedNotExistsSubquery
Browse files Browse the repository at this point in the history
  • Loading branch information
kszucs authored and cpcloud committed Oct 13, 2023
1 parent 462bd17 commit e31e8fd
Show file tree
Hide file tree
Showing 24 changed files with 41 additions and 350 deletions.
4 changes: 0 additions & 4 deletions ibis/backends/base/sql/alchemy/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,9 +158,6 @@ def _exists_subquery(t, op):
sub_ctx = ctx.subcontext()
clause = ctx.compiler.to_sql(filtered, sub_ctx, exists=True)

if isinstance(op, ops.NotExistsSubquery):
clause = sa.not_(clause)

return clause


Expand Down Expand Up @@ -563,7 +560,6 @@ class array_filter(FunctionElement):
ops.TableColumn: _table_column,
ops.TableArrayView: _table_array_view,
ops.ExistsSubquery: _exists_subquery,
ops.NotExistsSubquery: _exists_subquery,
# miscellaneous varargs
ops.Least: varargs(sa.func.least),
ops.Greatest: varargs(sa.func.greatest),
Expand Down
12 changes: 0 additions & 12 deletions ibis/backends/base/sql/compiler/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,23 +336,11 @@ def _any_expand(op):
return ops.Max(op.arg, where=op.where)


# Rewrite rule registered for ops.NotAny.
# Builds a zero literal with the same dtype as the argument and returns
# Min(arg == 0), forwarding the original `where` filter to the reduction.
@rewrites(ops.NotAny)
def _notany_expand(op):
zero = ops.Literal(0, dtype=op.arg.dtype)
return ops.Min(ops.Equals(op.arg, zero), where=op.where)


# Rewrite rule registered for ops.All.
# Returns a Min reduction over the same argument, keeping the `where` filter.
@rewrites(ops.All)
def _all_expand(op):
return ops.Min(op.arg, where=op.where)


# Rewrite rule registered for ops.NotAll.
# Builds a zero literal with the same dtype as the argument and returns
# Max(arg == 0), forwarding the original `where` filter to the reduction.
@rewrites(ops.NotAll)
def _notall_expand(op):
zero = ops.Literal(0, dtype=op.arg.dtype)
return ops.Max(ops.Equals(op.arg, zero), where=op.where)


@rewrites(ops.Cast)
def _rewrite_cast(op):
# TODO(kszucs): avoid the expression roundtrip
Expand Down
4 changes: 1 addition & 3 deletions ibis/backends/base/sql/registry/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,7 @@ def exists_subquery(translator, op):

subquery = ctx.get_compiled_expr(node)

prefix = "NOT " * isinstance(op, ops.NotExistsSubquery)
return f"{prefix}EXISTS (\n{util.indent(subquery, ctx.indent)}\n)"
return f"EXISTS (\n{util.indent(subquery, ctx.indent)}\n)"


# XXX this is not added to operation_registry, but looks like impala is
Expand Down Expand Up @@ -350,7 +349,6 @@ def count_star(translator, op):
ops.TimestampDiff: timestamp.timestamp_diff,
ops.TimestampFromUNIX: timestamp.timestamp_from_unix,
ops.ExistsSubquery: exists_subquery,
ops.NotExistsSubquery: exists_subquery,
# RowNumber, and rank functions starts with 0 in Ibis-land
ops.RowNumber: lambda *_: "row_number()",
ops.DenseRank: lambda *_: "dense_rank()",
Expand Down
10 changes: 0 additions & 10 deletions ibis/backends/bigquery/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,16 +113,6 @@ def _trans_param(self, op):
compiles = BigQueryExprTranslator.compiles


# BigQuery rewrite rule for ops.NotAll.
# Returns Any(Not(arg)) with the original `where` filter.
@BigQueryExprTranslator.rewrites(ops.NotAll)
def _rewrite_notall(op):
return ops.Any(ops.Not(op.arg), where=op.where)


# BigQuery rewrite rule for ops.NotAny.
# Returns All(Not(arg)) with the original `where` filter.
@BigQueryExprTranslator.rewrites(ops.NotAny)
def _rewrite_notany(op):
return ops.All(ops.Not(op.arg), where=op.where)


class BigQueryTableSetFormatter(sql_compiler.TableSetFormatter):
def _quote_identifier(self, name):
return sg.to_identifier(name).sql("bigquery")
Expand Down
2 changes: 0 additions & 2 deletions ibis/backends/bigquery/rewrites.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,4 @@ def bq_mean(op):
ops.Mean: bq_mean,
ops.Any: toolz.identity,
ops.All: toolz.identity,
ops.NotAny: toolz.identity,
ops.NotAll: toolz.identity,
}
15 changes: 0 additions & 15 deletions ibis/backends/clickhouse/compiler/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,18 +101,6 @@ def fn(node, _, **kwargs):
False, dtype="bool"
)

# replace `NotExistsSubquery` with `Not(ExistsSubquery)`
#
# this allows to avoid having another rule to negate ExistsSubquery
replace_notexists_subquery_with_not_exists = p.NotExistsSubquery(...) >> c.Not(
c.ExistsSubquery(...)
)

# clickhouse-specific rewrite to turn notany/notall into equivalent
# already-defined operations
replace_notany_with_min_not = p.NotAny(x, where=y) >> c.Min(c.Not(x), where=y)
replace_notall_with_max_not = p.NotAll(x, where=y) >> c.Max(c.Not(x), where=y)

# subtract one from ranking functions to convert from 1-indexed to 0-indexed
subtract_one_from_ranking_functions = p.WindowFunction(
p.RankBase | p.NTile
Expand All @@ -124,9 +112,6 @@ def fn(node, _, **kwargs):
replace_literals
| replace_in_column_with_table_array_view
| replace_empty_in_values_with_false
| replace_notexists_subquery_with_not_exists
| replace_notany_with_min_not
| replace_notall_with_max_not
| subtract_one_from_ranking_functions
| add_one_to_nth_value_input
)
Expand Down
33 changes: 0 additions & 33 deletions ibis/backends/dask/execution/aggregations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,7 @@
- ops.Aggregation
- ops.Any
- ops.NotAny
- ops.All
- ops.NotAll
"""

from __future__ import annotations
Expand Down Expand Up @@ -132,33 +129,3 @@ def execute_any_all_series_group_by(op, data, mask, aggcontext=None, **kwargs):
# here for future scaffolding.
result = aggcontext.agg(data, operator.methodcaller(name))
return result


# Dask executor for ops.NotAny / ops.NotAll, registered for dd.Series data
# with a mask that is either a dd.Series or None.
# Returns the `~`-inverted result of the "any"/"all" aggregation.
@execute_node.register((ops.NotAny, ops.NotAll), dd.Series, (dd.Series, type(None)))
def execute_notany_series(op, data, mask, aggcontext=None, **kwargs):
# Restrict the data to the masked rows when a mask is supplied.
if mask is not None:
data = data.loc[mask]

# Strip the leading "Not" from the op class name and lowercase the rest,
# giving the name of the aggregation to run.
name = type(op).__name__[len("Not") :].lower()
if isinstance(aggcontext, (agg_ctx.Summarize, agg_ctx.Transform)):
# Aggregate by name, then invert the aggregated result.
result = ~aggcontext.agg(data, name)
else:
# Note this branch is not currently hit in the dask backend but is
# here for future scaffolding.
method = operator.methodcaller(name)
result = aggcontext.agg(data, lambda data: ~method(data))
return result


# Dask executor for ops.NotAny / ops.NotAll on grouped data, registered for
# ddgb.SeriesGroupBy with a None mask. `mask` is accepted but not read here.
# Returns the `~`-inverted result of the "any"/"all" aggregation.
@execute_node.register((ops.NotAny, ops.NotAll), ddgb.SeriesGroupBy, type(None))
def execute_notany_series_group_by(op, data, mask, aggcontext=None, **kwargs):
# Strip the leading "Not" from the op class name and lowercase the rest,
# giving the name of the aggregation to run.
name = type(op).__name__[len("Not") :].lower()
if isinstance(aggcontext, (agg_ctx.Summarize, agg_ctx.Transform)):
# Aggregate by name, then invert the aggregated result.
result = ~aggcontext.agg(data, name)
else:
# Note this branch is not currently hit in the dask backend but is
# here for future scaffolding.
method = operator.methodcaller(name)
result = aggcontext.agg(data, lambda data: ~method(data))

return result
2 changes: 0 additions & 2 deletions ibis/backends/duckdb/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,6 @@ def compile_array(element, compiler, **kw):

# Identity rewrite: registered for each op listed in the decorators below
# and returns the incoming expression unchanged.
@rewrites(ops.Any)
@rewrites(ops.All)
@rewrites(ops.NotAny)
@rewrites(ops.NotAll)
@rewrites(ops.StringContains)
def _no_op(expr):
return expr
Expand Down
32 changes: 0 additions & 32 deletions ibis/backends/pandas/execution/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -847,38 +847,6 @@ def execute_any_all_series_group_by(op, data, mask, aggcontext=None, **kwargs):
return result


# Pandas executor for ops.NotAny / ops.NotAll, registered for pd.Series data
# with a mask that is either a pd.Series or None.
# Returns the `~`-inverted "any"/"all" aggregation, cast to bool when possible.
@execute_node.register((ops.NotAny, ops.NotAll), pd.Series, (pd.Series, type(None)))
def execute_notany_notall_series(op, data, mask, aggcontext=None, **kwargs):
# Lowercase the op class name and drop its first three characters ("not"),
# giving the name of the aggregation to run.
name = type(op).__name__.lower()[len("Not") :]
# Restrict the data to the masked rows when a mask is supplied.
if mask is not None:
data = data.loc[mask]
if isinstance(aggcontext, (agg_ctx.Summarize, agg_ctx.Transform)):
# Aggregate by name, then invert the aggregated result.
result = ~aggcontext.agg(data, name)
else:
# Otherwise hand the context a callable that inverts the method result.
method = operator.methodcaller(name)
result = aggcontext.agg(data, lambda data: ~method(data))
# Prefer a bool-typed result; if the cast raises TypeError, return the
# result as it is.
try:
return result.astype(bool)
except TypeError:
return result


# Pandas executor for ops.NotAny / ops.NotAll on grouped data, registered for
# SeriesGroupBy with a None mask.
# Returns the `~`-inverted "any"/"all" aggregation, cast to bool when possible.
@execute_node.register((ops.NotAny, ops.NotAll), SeriesGroupBy, type(None))
def execute_notany_notall_series_group_by(op, data, mask, aggcontext=None, **kwargs):
# Lowercase the op class name and drop its first three characters ("not"),
# giving the name of the aggregation to run.
name = type(op).__name__.lower()[len("Not") :]
# When a mask is present, filter the underlying object and regroup it
# using the groupings of the original group-by.
if mask is not None:
data = data.obj.loc[mask].groupby(get_grouping(data.grouper.groupings))
if isinstance(aggcontext, (agg_ctx.Summarize, agg_ctx.Transform)):
# Aggregate by name, then invert the aggregated result.
result = ~aggcontext.agg(data, name)
else:
# Otherwise hand the context a callable that inverts the method result.
method = operator.methodcaller(name)
result = aggcontext.agg(data, lambda data: ~method(data))
# Prefer a bool-typed result; if the cast raises TypeError, return the
# result as it is.
try:
return result.astype(bool)
except TypeError:
return result


# Executor for ops.CountStar, registered for pd.DataFrame data with a None
# third argument: returns len(data).
@execute_node.register(ops.CountStar, pd.DataFrame, type(None))
def execute_count_star_frame(op, data, _, **kwargs):
return len(data)
Expand Down
18 changes: 0 additions & 18 deletions ibis/backends/polars/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1058,24 +1058,6 @@ def execute_hash(op, **kw):
return translate(op.arg, **kw).hash()


# Polars translation for ops.NotAll.
# Translates the argument, then applies `.all()` followed by `.not_()`.
@translate.register(ops.NotAll)
def execute_not_all(op, **kw):
arg = op.arg
# With a `where` filter, wrap the argument in IfElse(where, arg, None)
# before translating it.
if (op_where := op.where) is not None:
arg = ops.IfElse(op_where, arg, None)

return translate(arg, **kw).all().not_()


# Polars translation for ops.NotAny.
# Translates the argument, then applies `.any()` followed by `.not_()`.
@translate.register(ops.NotAny)
def execute_not_any(op, **kw):
arg = op.arg
# With a `where` filter, wrap the argument in IfElse(where, arg, None)
# before translating it.
if (op_where := op.where) is not None:
arg = ops.IfElse(op_where, arg, None)

return translate(arg, **kw).any().not_()


def _arg_min_max(op, func, **kw):
key = op.key
arg = op.arg
Expand Down
2 changes: 0 additions & 2 deletions ibis/backends/postgres/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@ class PostgreSQLExprTranslator(AlchemyExprTranslator):

# Identity rewrite: registered for each op listed in the decorators below
# and returns the incoming expression unchanged.
@rewrites(ops.Any)
@rewrites(ops.All)
@rewrites(ops.NotAny)
@rewrites(ops.NotAll)
def _any_all_no_op(expr):
return expr

Expand Down
2 changes: 0 additions & 2 deletions ibis/backends/postgres/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,8 +608,6 @@ def _array_filter(t, op):
# boolean reductions
ops.Any: reduction(sa.func.bool_or),
ops.All: reduction(sa.func.bool_and),
ops.NotAny: reduction(lambda x: sa.func.bool_and(~x)),
ops.NotAll: reduction(lambda x: sa.func.bool_or(~x)),
# strings
ops.GroupConcat: _string_agg,
ops.Capitalize: unary(sa.func.initcap),
Expand Down
22 changes: 1 addition & 21 deletions ibis/backends/pyspark/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,21 +554,11 @@ def compile_any(t, op, *, aggcontext=None, **kwargs):
return compile_aggregator(t, op, fn=F.max, aggcontext=aggcontext, **kwargs)


# PySpark compilation for ops.NotAny: compiles the op via compile_any and
# applies `~` to the result.
# NOTE(review): extra positional *args are forwarded to compile_any, whose
# visible signature takes only keyword arguments after (t, op) -- a non-empty
# *args would presumably raise TypeError; confirm against callers.
@compiles(ops.NotAny)
def compile_notany(t, op, *args, aggcontext=None, **kwargs):
return ~compile_any(t, op, *args, aggcontext=aggcontext, **kwargs)


# PySpark compilation for ops.All: delegates to compile_aggregator with
# F.min as the aggregation function, forwarding all other arguments.
@compiles(ops.All)
def compile_all(t, op, *args, **kwargs):
return compile_aggregator(t, op, *args, fn=F.min, **kwargs)


# PySpark compilation for ops.NotAll: compiles the op via compile_all and
# applies `~` to the result.
@compiles(ops.NotAll)
def compile_notall(t, op, *, aggcontext=None, **kwargs):
return ~compile_all(t, op, aggcontext=aggcontext, **kwargs)


# PySpark compilation for ops.Count: delegates to compile_aggregator with
# F.count as the aggregation function.
@compiles(ops.Count)
def compile_count(t, op, **kwargs):
return compile_aggregator(t, op, fn=F.count, **kwargs)
Expand Down Expand Up @@ -1235,19 +1225,9 @@ def compile_window_function(t, op, **kwargs):
else:
pyspark_window = pyspark_window.rowsBetween(win_start, win_end)

orig_func = func
if isinstance(orig_func, (ops.NotAll, ops.NotAny)):
# For NotAll and NotAny, negation must be applied after .over(window)
# Here we rewrite node to be its negation, and negate it back after
# translation and window operation
func = func.negate()
else:
func = orig_func
result = t.translate(func, **kwargs, aggcontext=aggcontext).over(pyspark_window)

if isinstance(orig_func, (ops.NotAll, ops.NotAny)):
return ~result
elif isinstance(func, ops.RankBase):
if isinstance(func, ops.RankBase):
# result must be cast to long type for Rank / RowNumber
return result.astype("long") - 1
else:
Expand Down
2 changes: 0 additions & 2 deletions ibis/backends/snowflake/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,8 +407,6 @@ def _map_get(t, op):
ops.TypeOf: unary(lambda arg: sa.func.typeof(sa.func.to_variant(arg))),
ops.All: reduction(sa.func.booland_agg),
ops.Any: reduction(sa.func.boolor_agg),
ops.NotAll: reduction(lambda arg: sa.func.boolor_agg(~arg)),
ops.NotAny: reduction(lambda arg: sa.func.booland_agg(~arg)),
ops.BitAnd: reduction(sa.func.bitand_agg),
ops.BitOr: reduction(sa.func.bitor_agg),
ops.BitXor: reduction(sa.func.bitxor_agg),
Expand Down
4 changes: 2 additions & 2 deletions ibis/backends/tests/test_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def calc_zscore(s):
id="cumnotany",
marks=[
pytest.mark.broken(["mssql"], raises=sa.exc.ProgrammingError),
pytest.mark.notimpl(["dask"], raises=NotImplementedError),
pytest.mark.notimpl(["dask"], raises=com.OperationNotDefinedError),
pytest.mark.broken(["oracle"], raises=sa.exc.DatabaseError),
],
),
Expand Down Expand Up @@ -239,7 +239,7 @@ def calc_zscore(s):
id="cumnotall",
marks=[
pytest.mark.broken(["mssql"], raises=sa.exc.ProgrammingError),
pytest.mark.notimpl(["dask"], raises=NotImplementedError),
pytest.mark.notimpl(["dask"], raises=com.OperationNotDefinedError),
pytest.mark.broken(["oracle"], raises=sa.exc.DatabaseError),
],
),
Expand Down
2 changes: 0 additions & 2 deletions ibis/backends/trino/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@ class TrinoSQLExprTranslator(AlchemyExprTranslator):

# Identity rewrite: registered for each op listed in the decorators below
# and returns the incoming expression unchanged.
@rewrites(ops.Any)
@rewrites(ops.All)
@rewrites(ops.NotAny)
@rewrites(ops.NotAll)
@rewrites(ops.StringContains)
def _no_op(expr):
return expr
Expand Down
2 changes: 0 additions & 2 deletions ibis/backends/trino/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,8 +332,6 @@ def _array_intersect(t, op):
# boolean reductions
ops.Any: reduction(sa.func.bool_or),
ops.All: reduction(sa.func.bool_and),
ops.NotAny: reduction(lambda x: sa.func.bool_and(~x)),
ops.NotAll: reduction(lambda x: sa.func.bool_or(~x)),
ops.ArgMin: reduction(sa.func.min_by),
ops.ArgMax: reduction(sa.func.max_by),
# array ops
Expand Down
29 changes: 0 additions & 29 deletions ibis/expr/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,12 +625,6 @@ def predicate(node):
return any(g.traverse(predicate, node))


# Maps a boolean "any"-style reduction class to the unresolved
# exists-subquery class that is built in its place.
_ANY_OP_MAPPING = {
ops.Any: ops.UnresolvedExistsSubquery,
ops.NotAny: ops.UnresolvedNotExistsSubquery,
}


def find_predicates(node, flatten=True):
# TODO(kszucs): consider to remove flatten argument and compose with
# flatten_predicates instead
Expand Down Expand Up @@ -663,28 +657,6 @@ def find_subqueries(node: ops.Node, min_dependents=1) -> tuple[ops.Node, ...]:
)


# TODO(kszucs): move to types/logical.py
# Build an Any/NotAny-style expression from `expr`.
# - expr: must be an ir.Expr (asserted below).
# - any_op_class: the reduction class to build, ops.Any or ops.NotAny.
# - where: optional boolean filter, keyword-only.
# Returns the resulting operation converted back to an expression.
def _make_any(
expr,
any_op_class: type[ops.Any] | type[ops.NotAny],
*,
where: ir.BooleanValue | None = None,
):
assert isinstance(expr, ir.Expr), type(expr)

# Collect the immediate parent tables and the flattened predicates of the
# expression's operation.
tables = find_immediate_parent_tables(expr.op())
predicates = find_predicates(expr.op(), flatten=True)

if len(tables) > 1:
# More than one parent table: build the unresolved exists-subquery op
# mapped from `any_op_class`, passing the tables (as expressions) and
# the predicates. `where` is not used on this path.
op = _ANY_OP_MAPPING[any_op_class](
tables=[t.to_expr() for t in tables],
predicates=predicates,
)
else:
# Otherwise build the plain reduction, with the filter bound through
# expr._bind_reduction_filter.
op = any_op_class(expr, where=expr._bind_reduction_filter(where))
return op.to_expr()


# TODO(kszucs): use substitute instead
@functools.singledispatch
def _rewrite_filter(op, **kwargs):
Expand All @@ -708,7 +680,6 @@ def _rewrite_filter_reduction(op, name: str | None = None, **kwargs):
@_rewrite_filter.register(ops.TableColumn)
@_rewrite_filter.register(ops.Literal)
@_rewrite_filter.register(ops.ExistsSubquery)
@_rewrite_filter.register(ops.NotExistsSubquery)
@_rewrite_filter.register(ops.WindowFunction)
def _rewrite_filter_subqueries(op, **kwargs):
"""Don't rewrite any of these operations in filters."""
Expand Down
Loading

0 comments on commit e31e8fd

Please sign in to comment.