diff --git a/ibis/expr/api.py b/ibis/expr/api.py index bef8aa634db0..30f2fa35b958 100644 --- a/ibis/expr/api.py +++ b/ibis/expr/api.py @@ -956,16 +956,40 @@ def case() -> bl.SearchedCaseBuilder: """Begin constructing a case expression. Use the `.when` method on the resulting object followed by `.end` to create a - complete case. + complete case expression. Examples -------- >>> import ibis - >>> cond1 = ibis.literal(1) == 1 - >>> cond2 = ibis.literal(2) == 1 - >>> expr = ibis.case().when(cond1, 3).when(cond2, 4).end() - >>> expr - SearchedCase(...) + >>> from ibis import _ + >>> ibis.options.interactive = True + >>> t = ibis.memtable( + ... { + ... "left": [1, 2, 3, 4], + ... "symbol": ["+", "-", "*", "/"], + ... "right": [5, 6, 7, 8], + ... } + ... ) + >>> t.mutate( + ... result=( + ... ibis.case() + ... .when(_.symbol == "+", _.left + _.right) + ... .when(_.symbol == "-", _.left - _.right) + ... .when(_.symbol == "*", _.left * _.right) + ... .when(_.symbol == "/", _.left / _.right) + ... .end() + ... ) + ... ) + ┏━━━━━━━┳━━━━━━━━┳━━━━━━━┳━━━━━━━━━┓ + ┃ left ┃ symbol ┃ right ┃ result ┃ + ┡━━━━━━━╇━━━━━━━━╇━━━━━━━╇━━━━━━━━━┩ + │ int64 │ string │ int64 │ float64 │ + ├───────┼────────┼───────┼─────────┤ + │ 1 │ + │ 5 │ 6.0 │ + │ 2 │ - │ 6 │ -4.0 │ + │ 3 │ * │ 7 │ 21.0 │ + │ 4 │ / │ 8 │ 0.5 │ + └───────┴────────┴───────┴─────────┘ Returns ------- diff --git a/ibis/expr/builders.py b/ibis/expr/builders.py index e3287cb15373..49742242ba19 100644 --- a/ibis/expr/builders.py +++ b/ibis/expr/builders.py @@ -1,8 +1,9 @@ from __future__ import annotations import math -from typing import TYPE_CHECKING, Literal, Optional, Union +from typing import TYPE_CHECKING, Any, Literal, Optional, Union +import ibis import ibis.expr.datashape as ds import ibis.expr.datatypes as dt import ibis.expr.operations as ops @@ -13,8 +14,8 @@ from ibis.common.exceptions import IbisInputError from ibis.common.grounds import Concrete from ibis.common.typing import VarTuple # noqa: TCH001 -from ibis.expr.deferred import Deferred # noqa: TCH001 -from ibis.expr.operations.core import Value # noqa: TCH001 +from ibis.expr.deferred import Deferred +from ibis.expr.deferred import _resolve as resolve from ibis.expr.operations.relations import Relation # noqa: TCH001 from ibis.expr.types.relations import bind_expr @@ -26,56 +27,82 @@ class Builder(Concrete): pass -class CaseBuilder(Builder): - results: VarTuple[Value] = () - default: Optional[ops.Value] = None +class DeferredCase(Deferred): + """A deferred case statement.""" - def type(self): - return rlz.highest_precedence_dtype(self.results) + __slots__ = ("_builder",) + _builder: SearchedCaseBuilder - def when(self, case_expr, result_expr) -> Self: - """Add a new case-result pair. + def __init__(self, builder: SearchedCaseBuilder): + self._builder = builder + + def __repr__(self) -> str: + return "" + + def _resolve(self, param) -> Any: + cases = tuple(resolve(c, param) for c in self._builder.cases) + results = tuple(resolve(r, param) for r in self._builder.results) + default = resolve(self._builder.default, param) + return self._builder.copy(cases=cases, results=results, default=default).end() + + +class SearchedCaseBuilder(Builder): + """A case builder, used for constructing `ibis.case()` expressions.""" + + cases: VarTuple[Union[Deferred, ops.Value[dt.Boolean, ds.Any]]] = () + results: VarTuple[Union[Deferred, ops.Value]] = () + default: Optional[Union[None, Deferred, ops.Value]] = None + + def when(self, case_expr: Any, result_expr: Any) -> Self: + """Add a new condition and result to the `CASE` expression. Parameters ---------- case_expr - Expression to equality-compare with base expression. Must be - comparable with the base. + Predicate expression to use for this case. result_expr Value when the case predicate evaluates to true. """ - cases = self.cases + (case_expr,) - results = self.results + (result_expr,) - return self.copy(cases=cases, results=results) - - def else_(self, result_expr) -> Self: - """Construct an `ELSE` expression.""" - return self.copy(default=result_expr) + return self.copy( + cases=self.cases + (case_expr,), results=self.results + (result_expr,) + ) - def end(self) -> ir.Value: - default = self.default - if default is None: - default = ir.null().cast(self.type()) + def else_(self, result_expr: Any) -> Self: + """Add a default value for the `CASE` expression. - kwargs = dict(zip(self.__argnames__, self.__args__)) - kwargs["default"] = default + Parameters + ---------- + result_expr + Value to use when all case predicates evaluate to false. + """ + return self.copy(default=result_expr) - return self.__type__(**kwargs).to_expr() + def end(self) -> ir.Value | Deferred: + """Finish the `CASE` expression.""" + if ( + isinstance(self.default, Deferred) + or any(isinstance(c, Deferred) for c in self.cases) + or any(isinstance(r, Deferred) for r in self.results) + ): + return DeferredCase(self) + if (default := self.default) is None: + default = ibis.null().cast(rlz.highest_precedence_dtype(self.results)) + return ops.SearchedCase( + cases=self.cases, results=self.results, default=default + ).to_expr() -class SearchedCaseBuilder(CaseBuilder): - __type__ = ops.SearchedCase - cases: VarTuple[Value[dt.Boolean, ds.Any]] = () +class SimpleCaseBuilder(Builder): + """A case builder, used for constructing `Column.case()` expressions.""" -class SimpleCaseBuilder(CaseBuilder): - __type__ = ops.SimpleCase base: ops.Value - cases: VarTuple[Value] = () + cases: VarTuple[ops.Value] = () + results: VarTuple[ops.Value] = () + default: Optional[Union[None, ops.Value]] = None - @annotated - def when(self, case_expr: Value, result_expr: Value): - """Add a new case-result pair. + def when(self, case_expr: Any, result_expr: Any) -> Self: + """Add a new condition and result to the `CASE` expression. Parameters ---------- @@ -85,12 +112,37 @@ def when(self, case_expr: Value, result_expr: Value): result_expr Value when the case predicate evaluates to true. """ - if not rlz.comparable(self.base, case_expr): + if not isinstance(case_expr, ir.Value): + case_expr = ibis.literal(case_expr) + if not isinstance(result_expr, ir.Value): + result_expr = ibis.literal(result_expr) + + if not rlz.comparable(self.base, case_expr.op()): raise TypeError( f"Base expression {rlz._arg_type_error_format(self.base)} and " f"case {rlz._arg_type_error_format(case_expr)} are not comparable" ) - return super().when(case_expr, result_expr) + return self.copy( + cases=self.cases + (case_expr,), results=self.results + (result_expr,) + ) + + def else_(self, result_expr: Any) -> Self: + """Add a default value for the `CASE` expression. + + Parameters + ---------- + result_expr + Value to use when all case predicates evaluate to false. + """ + return self.copy(default=result_expr) + + def end(self) -> ir.Value: + """Finish the `CASE` expression.""" + if (default := self.default) is None: + default = ibis.null().cast(rlz.highest_precedence_dtype(self.results)) + return ops.SimpleCase( + cases=self.cases, results=self.results, default=default, base=self.base + ).to_expr() RowsWindowBoundary = ops.WindowBoundary[dt.Integer] @@ -112,9 +164,9 @@ class WindowBuilder(Builder): how: Literal["rows", "range"] = "rows" start: Optional[RangeWindowBoundary] = None end: Optional[RangeWindowBoundary] = None - groupings: VarTuple[Union[str, Deferred, Value]] = () - orderings: VarTuple[Union[str, Deferred, Value]] = () - max_lookback: Optional[Value[dt.Interval]] = None + groupings: VarTuple[Union[str, Deferred, ops.Value]] = () + orderings: VarTuple[Union[str, Deferred, ops.Value]] = () + max_lookback: Optional[ops.Value[dt.Interval]] = None def _maybe_cast_boundary(self, boundary, dtype): if boundary.dtype == dtype: diff --git a/ibis/tests/expr/test_case.py b/ibis/tests/expr/test_case.py index 1ef15a8fa69e..70d3b5dfb710 100644 --- a/ibis/tests/expr/test_case.py +++ b/ibis/tests/expr/test_case.py @@ -67,29 +67,57 @@ def test_simple_case_expr(table): def test_multiple_case_expr(table): - case1 = table.a == 5 - case2 = table.b == 128 - case3 = table.c == 1000 + expr = ( + ibis.case() + .when(table.a == 5, table.f) + .when(table.b == 128, table.b * 2) + .when(table.c == 1000, table.e) + .else_(table.d) + .end() + ) - result1 = table.f - result2 = table.b * 2 - result3 = table.e + # deferred cases + deferred = ( + ibis.case() + .when(_.a == 5, table.f) + .when(_.b == 128, table.b * 2) + .when(_.c == 1000, table.e) + .else_(table.d) + .end() + ) + expr2 = deferred.resolve(table) - default = table.d + # deferred results + expr3 = ( + ibis.case() + .when(table.a == 5, _.f) + .when(table.b == 128, _.b * 2) + .when(table.c == 1000, _.e) + .else_(table.d) + .end() + .resolve(table) + ) - expr = ( + # deferred default + expr4 = ( ibis.case() - .when(case1, result1) - .when(case2, result2) - .when(case3, result3) - .else_(default) + .when(table.a == 5, table.f) + .when(table.b == 128, table.b * 2) + .when(table.c == 1000, table.e) + .else_(_.d) .end() + .resolve(table) ) + assert repr(deferred) == "" + assert expr.equals(expr2) + assert expr.equals(expr3) + assert expr.equals(expr4) + op = expr.op() assert isinstance(expr, ir.FloatingColumn) assert isinstance(op, ops.SearchedCase) - assert op.default == default.op() + assert op.default == table.d.op() def test_pickle_multiple_case_node(table):