Skip to content

Commit

Permalink
refactor: rename ops.Where to ops.IfElse
Browse files Browse the repository at this point in the history
  • Loading branch information
jcrist committed Sep 28, 2023
1 parent 995c1bc commit a64b7ad
Show file tree
Hide file tree
Showing 29 changed files with 92 additions and 92 deletions.
2 changes: 1 addition & 1 deletion ibis/backends/base/sql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ def _(t, op):
func = ".".join(filter(None, (op.__udf_namespace__, op.__func_name__)))
args = ", ".join(
t.translate(
ops.Where(where, arg, NA)
ops.IfElse(where, arg, NA)
if (where := op.where) is not None
else arg
)
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/base/sql/alchemy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -771,7 +771,7 @@ def _(t, op):
elif t._has_reduction_filter_syntax:
return func(*map(t.translate, args)).filter(t.translate(where))
else:
return func(*(t.translate(ops.Where(where, arg, NA)) for arg in args))
return func(*(t.translate(ops.IfElse(where, arg, NA)) for arg in args))

def _register_udfs(self, expr: ir.Expr) -> None:
with self.begin() as con:
Expand Down
8 changes: 4 additions & 4 deletions ibis/backends/base/sql/alchemy/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def variance_compiler(t, op):
func = getattr(sa.func, f"{func_name}{suffix[op.how]}")

if op.where is not None:
arg = ops.Where(op.where, arg, None)
arg = ops.IfElse(op.where, arg, None)

return func(t.translate(arg))

Expand Down Expand Up @@ -478,7 +478,7 @@ def _count_star(t, op):
if t._has_reduction_filter_syntax:
return sa.func.count().filter(t.translate(where))

return sa.func.count(t.translate(ops.Where(where, 1, None)))
return sa.func.count(t.translate(ops.IfElse(where, 1, None)))


def _count_distinct_star(t, op):
Expand All @@ -502,7 +502,7 @@ def _count_distinct_star(t, op):
"filter with more than one column"
)

return sa.func.count(t.translate(ops.Where(op.where, sa.distinct(*cols), None)))
return sa.func.count(t.translate(ops.IfElse(op.where, sa.distinct(*cols), None)))


def _extract(fmt: str):
Expand Down Expand Up @@ -652,7 +652,7 @@ class array_filter(FunctionElement):
ops.IdenticalTo: fixed_arity(
sa.sql.expression.ColumnElement.is_not_distinct_from, 2
),
ops.Where: fixed_arity(
ops.IfElse: fixed_arity(
lambda predicate, value_if_true, value_if_false: sa.case(
(predicate, value_if_true),
else_=value_if_false,
Expand Down
4 changes: 2 additions & 2 deletions ibis/backends/base/sql/alchemy/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def _reduction(self, sa_func, op):
return sa_func(*sa_args).filter(self.translate(where))
else:
sa_args = tuple(
self.translate(ops.Where(where, arg, None)) for arg in argtuple
self.translate(ops.IfElse(where, arg, None)) for arg in argtuple
)
else:
sa_args = tuple(map(self.translate, argtuple))
Expand All @@ -135,7 +135,7 @@ def _reduction(self, sa_func, op):
def _nullifzero(op):
arg = op.arg
condition = ops.Equals(arg, ops.Literal(0, dtype=op.arg.dtype))
return ops.Where(condition, ibis.NA, arg)
return ops.IfElse(condition, ibis.NA, arg)


# TODO This was previously implemented with the legacy `@compiles` decorator.
Expand Down
4 changes: 2 additions & 2 deletions ibis/backends/base/sql/compiler/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,14 +376,14 @@ def _rewrite_clip(op):
if (upper := op.upper) is not None:
clipped_lower = ops.Least((arg, ops.Cast(upper, dtype)))
if dtype.nullable:
arg = ops.Where(arg_is_null, arg, clipped_lower)
arg = ops.IfElse(arg_is_null, arg, clipped_lower)
else:
arg = clipped_lower

if (lower := op.lower) is not None:
clipped_upper = ops.Greatest((arg, ops.Cast(lower, dtype)))
if dtype.nullable:
arg = ops.Where(arg_is_null, arg, clipped_upper)
arg = ops.IfElse(arg_is_null, arg, clipped_upper)
else:
arg = clipped_upper

Expand Down
4 changes: 2 additions & 2 deletions ibis/backends/base/sql/registry/aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

def _reduction_format(translator, func_name, where, arg, *args):
if where is not None:
arg = ops.Where(where, arg, ibis.NA)
arg = ops.IfElse(where, arg, ibis.NA)

return "{}({})".format(
func_name,
Expand Down Expand Up @@ -38,7 +38,7 @@ def formatter(translator, op):

def count_distinct(translator, op):
if op.where is not None:
arg_formatted = translator.translate(ops.Where(op.where, op.arg, None))
arg_formatted = translator.translate(ops.IfElse(op.where, op.arg, None))
else:
arg_formatted = translator.translate(op.arg)
return f"count(DISTINCT {arg_formatted})"
2 changes: 1 addition & 1 deletion ibis/backends/base/sql/registry/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ def count_star(translator, op):
ops.Coalesce: varargs("coalesce"),
ops.Greatest: varargs("greatest"),
ops.Least: varargs("least"),
ops.Where: fixed_arity("if", 3),
ops.IfElse: fixed_arity("if", 3),
ops.Between: between,
ops.InValues: binary_infix.in_values,
ops.InColumn: binary_infix.in_column,
Expand Down
16 changes: 8 additions & 8 deletions ibis/backends/bigquery/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ def _arbitrary(translator, op):
arg, how, where = op.args

if where is not None:
arg = ops.Where(where, arg, ibis.NA)
arg = ops.IfElse(where, arg, ibis.NA)

if how != "first":
raise com.UnsupportedOperationError(
Expand All @@ -406,7 +406,7 @@ def _first(translator, op):
where = op.where

if where is not None:
arg = ops.Where(where, arg, ibis.NA)
arg = ops.IfElse(where, arg, ibis.NA)

arg = translator.translate(arg)
return f"ARRAY_AGG({arg} IGNORE NULLS)[SAFE_OFFSET(0)]"
Expand All @@ -417,7 +417,7 @@ def _last(translator, op):
where = op.where

if where is not None:
arg = ops.Where(where, arg, ibis.NA)
arg = ops.IfElse(where, arg, ibis.NA)

arg = translator.translate(arg)
return f"ARRAY_REVERSE(ARRAY_AGG({arg} IGNORE NULLS))[SAFE_OFFSET(0)]"
Expand Down Expand Up @@ -574,7 +574,7 @@ def compiles_approx(translator, op):
where = op.where

if where is not None:
arg = ops.Where(where, arg, ibis.NA)
arg = ops.IfElse(where, arg, ibis.NA)

return f"APPROX_QUANTILES({translator.translate(arg)}, 2)[OFFSET(1)]"

Expand All @@ -585,8 +585,8 @@ def translate(translator, op):
right = op.right

if (where := op.where) is not None:
left = ops.Where(where, left, None)
right = ops.Where(where, right, None)
left = ops.IfElse(where, left, None)
right = ops.IfElse(where, right, None)

left = translator.translate(
ops.Cast(left, dt.int64) if left.dtype.is_boolean() else left
Expand Down Expand Up @@ -648,15 +648,15 @@ def _zeroifnull(t, op):
def _array_agg(t, op):
arg = op.arg
if (where := op.where) is not None:
arg = ops.Where(where, arg, ibis.NA)
arg = ops.IfElse(where, arg, ibis.NA)
return f"ARRAY_AGG({t.translate(arg)} IGNORE NULLS)"


def _arg_min_max(sort_dir: Literal["ASC", "DESC"]):
def translate(t, op: ops.ArgMin | ops.ArgMax) -> str:
arg = op.arg
if (where := op.where) is not None:
arg = ops.Where(where, arg, None)
arg = ops.IfElse(where, arg, None)
arg = t.translate(arg)
key = t.translate(op.key)
return f"ARRAY_AGG({arg} IGNORE NULLS ORDER BY {key} {sort_dir} LIMIT 1)[SAFE_OFFSET(0)]"
Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/clickhouse/compiler/values.py
Original file line number Diff line number Diff line change
Expand Up @@ -792,7 +792,7 @@ def formatter(op, *, left, right, **_):
ops.E: "e",
# for more than 2 args this should be arrayGreatest|Least(array([]))
# because clickhouse"s greatest and least doesn"t support varargs
ops.Where: "if",
ops.IfElse: "if",
ops.ArrayLength: "length",
ops.Unnest: "arrayJoin",
ops.Degrees: "degrees",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT
CASE WHEN t0.float_col > 0 THEN t0.int_col ELSE t0.bigint_col END AS "Where(Greater(float_col, 0), int_col, bigint_col)"
CASE WHEN t0.float_col > 0 THEN t0.int_col ELSE t0.bigint_col END AS "IfElse(Greater(float_col, 0), int_col, bigint_col)"
FROM functional_alltypes AS t0
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SELECT
SUM(CASE WHEN isNull(t0.string_col) THEN 1 ELSE 0 END) AS "Sum(Where(IsNull(string_col), 1, 0))"
SUM(CASE WHEN isNull(t0.string_col) THEN 1 ELSE 0 END) AS "Sum(IfElse(IsNull(string_col), 1, 0))"
FROM functional_alltypes AS t0
16 changes: 8 additions & 8 deletions ibis/backends/dask/execution/indexing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Execution rules for ops.Where operations."""
"""Execution rules for ops.IfElse operations."""

from __future__ import annotations

Expand All @@ -10,10 +10,10 @@
from ibis.backends.pandas.execution.generic import pd_where


@execute_node.register(ops.Where, (dd.Series, *boolean_types), dd.Series, dd.Series)
@execute_node.register(ops.Where, (dd.Series, *boolean_types), dd.Series, simple_types)
@execute_node.register(ops.Where, (dd.Series, *boolean_types), simple_types, dd.Series)
@execute_node.register(ops.Where, (dd.Series, *boolean_types), type(None), type(None))
@execute_node.register(ops.IfElse, (dd.Series, *boolean_types), dd.Series, dd.Series)
@execute_node.register(ops.IfElse, (dd.Series, *boolean_types), dd.Series, simple_types)
@execute_node.register(ops.IfElse, (dd.Series, *boolean_types), simple_types, dd.Series)
@execute_node.register(ops.IfElse, (dd.Series, *boolean_types), type(None), type(None))
def execute_node_where(op, cond, true, false, **kwargs):
if any(isinstance(x, (dd.Series, dd.core.Scalar)) for x in (cond, true, false)):
return dd.map_partitions(pd_where, cond, true, false)
Expand All @@ -26,6 +26,6 @@ def execute_node_where(op, cond, true, false, **kwargs):
# promotion.
for typ in (str, *scalar_types):
for cond_typ in (dd.Series, *boolean_types):
execute_node.register(ops.Where, cond_typ, typ, typ)(execute_node_where)
execute_node.register(ops.Where, cond_typ, type(None), typ)(execute_node_where)
execute_node.register(ops.Where, cond_typ, typ, type(None))(execute_node_where)
execute_node.register(ops.IfElse, cond_typ, typ, typ)(execute_node_where)
execute_node.register(ops.IfElse, cond_typ, type(None), typ)(execute_node_where)
execute_node.register(ops.IfElse, cond_typ, typ, type(None))(execute_node_where)
2 changes: 1 addition & 1 deletion ibis/backends/datafusion/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,7 +628,7 @@ def agg_udf(op, **kw):
# it so this will fail if `where` is in the function's signature.
#
# Filtering aggregates are not yet possible.
translate(arg if where is None else ops.Where(where, arg, NA), **kw)
translate(arg if where is None else ops.IfElse(where, arg, NA), **kw)
for argname, arg in zip(op.argnames, op.args)
if argname != "where"
)
Expand Down
4 changes: 2 additions & 2 deletions ibis/backends/druid/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@

def _sign(t, op):
arg = op.arg
cond1 = ops.Where(ops.Greater(arg, 0), 1, -1)
cond2 = ops.Where(ops.Equals(arg, 0), 0, cond1)
cond1 = ops.IfElse(ops.Greater(arg, 0), 1, -1)
cond2 = ops.IfElse(ops.Equals(arg, 0), 0, cond1)
return t.translate(cond2)


Expand Down
2 changes: 1 addition & 1 deletion ibis/backends/flink/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def _floor_divide(translator: ExprTranslator, op: ops.Node) -> str:
ops.ExtractSecond: _extract_field("second"), # equivalent to SECOND(timestamp)
# Other operations
ops.Literal: _literal,
ops.Where: _filter,
ops.IfElse: _filter,
ops.TimestampDiff: _timestamp_diff,
ops.TimestampFromUNIX: _timestamp_from_unix,
ops.Window: _window,
Expand Down
6 changes: 3 additions & 3 deletions ibis/backends/mssql/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ def reduction_compiler(t, op):
nullable = arg.dtype.nullable
arg = ops.Cast(arg, dt.dtype(cast_type)(nullable=nullable))
else:
arg = ops.Where(arg, 1, 0)
arg = ops.IfElse(arg, 1, 0)

if where is not None:
arg = ops.Where(where, arg, None)
arg = ops.IfElse(where, arg, None)
return func(t.translate(arg))

return reduction_compiler
Expand Down Expand Up @@ -117,7 +117,7 @@ def _timestamp_truncate(t, op):
ops.Min: _reduction(sa.func.min),
ops.Sum: _reduction(sa.func.sum),
ops.Mean: _reduction(sa.func.avg, "float64"),
ops.Where: fixed_arity(sa.func.iif, 3),
ops.IfElse: fixed_arity(sa.func.iif, 3),
# string methods
ops.Capitalize: unary(
lambda arg: sa.func.concat(
Expand Down
4 changes: 2 additions & 2 deletions ibis/backends/mysql/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def _literal(_, op):

def _group_concat(t, op):
if op.where is not None:
arg = t.translate(ops.Where(op.where, op.arg, ibis.NA))
arg = t.translate(ops.IfElse(op.where, op.arg, ibis.NA))
else:
arg = t.translate(op.arg)
sep = t.translate(op.sep)
Expand Down Expand Up @@ -172,7 +172,7 @@ def compiles_mysql_trim(element, compiler, **kw):
ops.Literal: _literal,
ops.IfNull: fixed_arity(sa.func.ifnull, 2),
# static checks are not happy with using "if" as a property
ops.Where: fixed_arity(getattr(sa.func, "if"), 3),
ops.IfElse: fixed_arity(getattr(sa.func, "if"), 3),
# strings
ops.StringFind: _string_find,
ops.FindInSet: (
Expand Down
14 changes: 7 additions & 7 deletions ibis/backends/pandas/execution/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1268,10 +1268,10 @@ def pd_where(cond, true, false):
return false


@execute_node.register(ops.Where, (pd.Series, *boolean_types), pd.Series, pd.Series)
@execute_node.register(ops.Where, (pd.Series, *boolean_types), pd.Series, simple_types)
@execute_node.register(ops.Where, (pd.Series, *boolean_types), simple_types, pd.Series)
@execute_node.register(ops.Where, (pd.Series, *boolean_types), type(None), type(None))
@execute_node.register(ops.IfElse, (pd.Series, *boolean_types), pd.Series, pd.Series)
@execute_node.register(ops.IfElse, (pd.Series, *boolean_types), pd.Series, simple_types)
@execute_node.register(ops.IfElse, (pd.Series, *boolean_types), simple_types, pd.Series)
@execute_node.register(ops.IfElse, (pd.Series, *boolean_types), type(None), type(None))
def execute_node_where(op, cond, true, false, **kwargs):
return pd_where(cond, true, false)

Expand All @@ -1281,9 +1281,9 @@ def execute_node_where(op, cond, true, false, **kwargs):
# promotion.
for typ in (str, *scalar_types):
for cond_typ in (pd.Series, *boolean_types):
execute_node.register(ops.Where, cond_typ, typ, typ)(execute_node_where)
execute_node.register(ops.Where, cond_typ, type(None), typ)(execute_node_where)
execute_node.register(ops.Where, cond_typ, typ, type(None))(execute_node_where)
execute_node.register(ops.IfElse, cond_typ, typ, typ)(execute_node_where)
execute_node.register(ops.IfElse, cond_typ, type(None), typ)(execute_node_where)
execute_node.register(ops.IfElse, cond_typ, typ, type(None))(execute_node_where)


@execute_node.register(ops.DatabaseTable, PandasBackend)
Expand Down
16 changes: 8 additions & 8 deletions ibis/backends/polars/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,8 +385,8 @@ def nullifzero(op, **kw):
return pl.when(arg == 0).then(None).otherwise(arg)


@translate.register(ops.Where)
def where(op, **kw):
@translate.register(ops.IfElse)
def ifelse(op, **kw):
bool_expr = translate(op.bool_expr, **kw)
true_expr = translate(op.true_expr, **kw)
false_null_expr = translate(op.false_null_expr, **kw)
Expand Down Expand Up @@ -739,8 +739,8 @@ def correlation(op, **kw):
y = ops.Cast(y, dt.Int32(nullable=y_type.nullable))

if (where := op.where) is not None:
x = ops.Where(where, x, None)
y = ops.Where(where, y, None)
x = ops.IfElse(where, x, None)
y = ops.IfElse(where, y, None)

return pl.corr(translate(x, **kw), translate(y, **kw))

Expand Down Expand Up @@ -1105,7 +1105,7 @@ def execute_hash(op, **kw):
def execute_not_all(op, **kw):
arg = op.arg
if (op_where := op.where) is not None:
arg = ops.Where(op_where, arg, None)
arg = ops.IfElse(op_where, arg, None)

return translate(arg, **kw).all().is_not()

Expand All @@ -1114,7 +1114,7 @@ def execute_not_all(op, **kw):
def execute_not_any(op, **kw):
arg = op.arg
if (op_where := op.where) is not None:
arg = ops.Where(op_where, arg, None)
arg = ops.IfElse(op_where, arg, None)

return translate(arg, **kw).any().is_not()

Expand All @@ -1124,8 +1124,8 @@ def _arg_min_max(op, func, **kw):
arg = op.arg

if (op_where := op.where) is not None:
key = ops.Where(op_where, key, None)
arg = ops.Where(op_where, arg, None)
key = ops.IfElse(op_where, key, None)
arg = ops.IfElse(op_where, arg, None)

translate_arg = translate(arg, **kw)
translate_key = translate(key, **kw)
Expand Down
Loading

0 comments on commit a64b7ad

Please sign in to comment.