Skip to content

Commit

Permalink
feat(api): support deferred arguments in ibis.case()
Browse files Browse the repository at this point in the history
  • Loading branch information
jcrist authored and cpcloud committed Sep 29, 2023
1 parent 13f593b commit 6f9f7c5
Show file tree
Hide file tree
Showing 3 changed files with 163 additions and 59 deletions.
36 changes: 30 additions & 6 deletions ibis/expr/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -956,16 +956,40 @@ def case() -> bl.SearchedCaseBuilder:
"""Begin constructing a case expression.
Use the `.when` method on the resulting object followed by `.end` to create a
complete case.
complete case expression.
Examples
--------
>>> import ibis
>>> cond1 = ibis.literal(1) == 1
>>> cond2 = ibis.literal(2) == 1
>>> expr = ibis.case().when(cond1, 3).when(cond2, 4).end()
>>> expr
SearchedCase(...)
>>> from ibis import _
>>> ibis.options.interactive = True
>>> t = ibis.memtable(
... {
... "left": [1, 2, 3, 4],
... "symbol": ["+", "-", "*", "/"],
... "right": [5, 6, 7, 8],
... }
... )
>>> t.mutate(
... result=(
... ibis.case()
... .when(_.symbol == "+", _.left + _.right)
... .when(_.symbol == "-", _.left - _.right)
... .when(_.symbol == "*", _.left * _.right)
... .when(_.symbol == "/", _.left / _.right)
... .end()
... )
... )
┏━━━━━━━┳━━━━━━━━┳━━━━━━━┳━━━━━━━━━┓
┃ left ┃ symbol ┃ right ┃ result ┃
┡━━━━━━━╇━━━━━━━━╇━━━━━━━╇━━━━━━━━━┩
│ int64 │ string │ int64 │ float64 │
├───────┼────────┼───────┼─────────┤
│ 1 │ + │ 5 │ 6.0 │
│ 2 │ - │ 6 │ -4.0 │
│ 3 │ * │ 7 │ 21.0 │
│ 4 │ / │ 8 │ 0.5 │
└───────┴────────┴───────┴─────────┘
Returns
-------
Expand Down
132 changes: 92 additions & 40 deletions ibis/expr/builders.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from __future__ import annotations

import math
from typing import TYPE_CHECKING, Literal, Optional, Union
from typing import TYPE_CHECKING, Any, Literal, Optional, Union

import ibis
import ibis.expr.datashape as ds
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
Expand All @@ -13,8 +14,8 @@
from ibis.common.exceptions import IbisInputError
from ibis.common.grounds import Concrete
from ibis.common.typing import VarTuple # noqa: TCH001
from ibis.expr.deferred import Deferred # noqa: TCH001
from ibis.expr.operations.core import Value # noqa: TCH001
from ibis.expr.deferred import Deferred
from ibis.expr.deferred import _resolve as resolve
from ibis.expr.operations.relations import Relation # noqa: TCH001
from ibis.expr.types.relations import bind_expr

Expand All @@ -26,56 +27,82 @@ class Builder(Concrete):
pass


class CaseBuilder(Builder):
results: VarTuple[Value] = ()
default: Optional[ops.Value] = None
class DeferredCase(Deferred):
"""A deferred case statement."""

def type(self):
return rlz.highest_precedence_dtype(self.results)
__slots__ = ("_builder",)
_builder: SearchedCaseBuilder

def when(self, case_expr, result_expr) -> Self:
"""Add a new case-result pair.
def __init__(self, builder: SearchedCaseBuilder):
self._builder = builder

def __repr__(self) -> str:
    """Return the fixed placeholder string ``"<case>"``."""
    return "<case>"

def _resolve(self, param) -> Any:
    """Resolve the wrapped builder against `param` and finish it.

    Every case, every result and the default held by the wrapped builder
    are passed through `resolve` together with `param`. The builder is then
    copied with those resolved values and `.end()` is called on the copy.
    """
    builder = self._builder
    resolved_cases = tuple(resolve(case, param) for case in builder.cases)
    resolved_results = tuple(resolve(result, param) for result in builder.results)
    resolved_default = resolve(builder.default, param)
    finished = builder.copy(
        cases=resolved_cases,
        results=resolved_results,
        default=resolved_default,
    )
    return finished.end()


class SearchedCaseBuilder(Builder):
"""A case builder, used for constructing `ibis.case()` expressions."""

cases: VarTuple[Union[Deferred, ops.Value[dt.Boolean, ds.Any]]] = ()
results: VarTuple[Union[Deferred, ops.Value]] = ()
default: Optional[Union[None, Deferred, ops.Value]] = None

def when(self, case_expr: Any, result_expr: Any) -> Self:
"""Add a new condition and result to the `CASE` expression.
Parameters
----------
case_expr
Expression to equality-compare with base expression. Must be
comparable with the base.
Predicate expression to use for this case.
result_expr
Value when the case predicate evaluates to true.
"""
cases = self.cases + (case_expr,)
results = self.results + (result_expr,)
return self.copy(cases=cases, results=results)

def else_(self, result_expr) -> Self:
"""Construct an `ELSE` expression."""
return self.copy(default=result_expr)
return self.copy(
cases=self.cases + (case_expr,), results=self.results + (result_expr,)
)

def end(self) -> ir.Value:
default = self.default
if default is None:
default = ir.null().cast(self.type())
def else_(self, result_expr: Any) -> Self:
"""Add a default value for the `CASE` expression.
kwargs = dict(zip(self.__argnames__, self.__args__))
kwargs["default"] = default
Parameters
----------
result_expr
Value to use when all case predicates evaluate to false.
"""
return self.copy(default=result_expr)

return self.__type__(**kwargs).to_expr()
def end(self) -> ir.Value | Deferred:
    """Finish the `CASE` expression.

    If the default, any case predicate, or any result is a `Deferred`,
    a `DeferredCase` wrapping this builder is returned instead of a
    concrete expression.

    Otherwise an `ops.SearchedCase` is built from the accumulated cases and
    results and converted with `.to_expr()`. When no default was supplied,
    a null cast to `rlz.highest_precedence_dtype(self.results)` is used.
    """
    # Same inspection order as before: default first, then cases, then results.
    candidates = (self.default, *self.cases, *self.results)
    if any(isinstance(candidate, Deferred) for candidate in candidates):
        return DeferredCase(self)

    default = self.default
    if default is None:
        result_dtype = rlz.highest_precedence_dtype(self.results)
        default = ibis.null().cast(result_dtype)

    node = ops.SearchedCase(cases=self.cases, results=self.results, default=default)
    return node.to_expr()

class SearchedCaseBuilder(CaseBuilder):
__type__ = ops.SearchedCase
cases: VarTuple[Value[dt.Boolean, ds.Any]] = ()

class SimpleCaseBuilder(Builder):
"""A case builder, used for constructing `Column.case()` expressions."""

class SimpleCaseBuilder(CaseBuilder):
__type__ = ops.SimpleCase
base: ops.Value
cases: VarTuple[Value] = ()
cases: VarTuple[ops.Value] = ()
results: VarTuple[ops.Value] = ()
default: Optional[Union[None, ops.Value]] = None

@annotated
def when(self, case_expr: Value, result_expr: Value):
"""Add a new case-result pair.
def when(self, case_expr: Any, result_expr: Any) -> Self:
"""Add a new condition and result to the `CASE` expression.
Parameters
----------
Expand All @@ -85,12 +112,37 @@ def when(self, case_expr: Value, result_expr: Value):
result_expr
Value when the case predicate evaluates to true.
"""
if not rlz.comparable(self.base, case_expr):
if not isinstance(case_expr, ir.Value):
case_expr = ibis.literal(case_expr)
if not isinstance(result_expr, ir.Value):
result_expr = ibis.literal(result_expr)

if not rlz.comparable(self.base, case_expr.op()):
raise TypeError(
f"Base expression {rlz._arg_type_error_format(self.base)} and "
f"case {rlz._arg_type_error_format(case_expr)} are not comparable"
)
return super().when(case_expr, result_expr)
return self.copy(
cases=self.cases + (case_expr,), results=self.results + (result_expr,)
)

def else_(self, result_expr: Any) -> Self:
    """Add a default value for the `CASE` expression.

    Parameters
    ----------
    result_expr
        Value to use when all case predicates evaluate to false.

    Returns
    -------
    Self
        The value returned by `self.copy(default=result_expr)`.
    """
    return self.copy(default=result_expr)

def end(self) -> ir.Value:
    """Finish the `CASE` expression.

    Builds an `ops.SimpleCase` from the base value and the accumulated
    cases and results, then converts it with `.to_expr()`. When no default
    was supplied, a null cast to
    `rlz.highest_precedence_dtype(self.results)` is used.
    """
    default = self.default
    if default is None:
        result_dtype = rlz.highest_precedence_dtype(self.results)
        default = ibis.null().cast(result_dtype)

    node = ops.SimpleCase(
        base=self.base,
        cases=self.cases,
        results=self.results,
        default=default,
    )
    return node.to_expr()


RowsWindowBoundary = ops.WindowBoundary[dt.Integer]
Expand All @@ -112,9 +164,9 @@ class WindowBuilder(Builder):
how: Literal["rows", "range"] = "rows"
start: Optional[RangeWindowBoundary] = None
end: Optional[RangeWindowBoundary] = None
groupings: VarTuple[Union[str, Deferred, Value]] = ()
orderings: VarTuple[Union[str, Deferred, Value]] = ()
max_lookback: Optional[Value[dt.Interval]] = None
groupings: VarTuple[Union[str, Deferred, ops.Value]] = ()
orderings: VarTuple[Union[str, Deferred, ops.Value]] = ()
max_lookback: Optional[ops.Value[dt.Interval]] = None

def _maybe_cast_boundary(self, boundary, dtype):
if boundary.dtype == dtype:
Expand Down
54 changes: 41 additions & 13 deletions ibis/tests/expr/test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,29 +67,57 @@ def test_simple_case_expr(table):


def test_multiple_case_expr(table):
case1 = table.a == 5
case2 = table.b == 128
case3 = table.c == 1000
expr = (
ibis.case()
.when(table.a == 5, table.f)
.when(table.b == 128, table.b * 2)
.when(table.c == 1000, table.e)
.else_(table.d)
.end()
)

result1 = table.f
result2 = table.b * 2
result3 = table.e
# deferred cases
deferred = (
ibis.case()
.when(_.a == 5, table.f)
.when(_.b == 128, table.b * 2)
.when(_.c == 1000, table.e)
.else_(table.d)
.end()
)
expr2 = deferred.resolve(table)

default = table.d
# deferred results
expr3 = (
ibis.case()
.when(table.a == 5, _.f)
.when(table.b == 128, _.b * 2)
.when(table.c == 1000, _.e)
.else_(table.d)
.end()
.resolve(table)
)

expr = (
# deferred default
expr4 = (
ibis.case()
.when(case1, result1)
.when(case2, result2)
.when(case3, result3)
.else_(default)
.when(table.a == 5, table.f)
.when(table.b == 128, table.b * 2)
.when(table.c == 1000, table.e)
.else_(_.d)
.end()
.resolve(table)
)

assert repr(deferred) == "<case>"
assert expr.equals(expr2)
assert expr.equals(expr3)
assert expr.equals(expr4)

op = expr.op()
assert isinstance(expr, ir.FloatingColumn)
assert isinstance(op, ops.SearchedCase)
assert op.default == default.op()
assert op.default == table.d.op()


def test_pickle_multiple_case_node(table):
Expand Down

0 comments on commit 6f9f7c5

Please sign in to comment.