Skip to content

Commit

Permalink
fix(case): fix dshape, error on noncomparable and empty cases
Browse files Browse the repository at this point in the history
This pins down the expected behavior for cases before tackling
the case() to cases() switch in
#9096
so that that PR can be simpler.

I move the validation for comparable-ness down into the operation so that
the logic is consolidated in one place.
In #9096 there might be multiple places that construct an ops.SimpleCase, and we don't want
to have to implement the validation in all
calling locations.

We could consider relaxing the requirement that cases be non-empty later, but for now let's be strict.

I already fixed the shape of ops.SearchedCase in #9334,
but it looks like in that PR I forgot to also fix ops.SimpleCase, so I do that fix here.
  • Loading branch information
NickCrews committed Jul 16, 2024
1 parent 7a0b21e commit aee467b
Show file tree
Hide file tree
Showing 23 changed files with 111 additions and 68 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,6 @@ SELECT
`t0`.`value` <= 3
)
THEN 1
ELSE CAST(NULL AS INT64)
ELSE NULL
END AS `tmp`
FROM `t` AS `t0`
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ FROM (
`t0`.`f` <= 50
)
THEN 3
ELSE CAST(NULL AS TINYINT)
ELSE NULL
END AS `tier`,
COUNT(*) AS `CountStar(alltypes)`
FROM `alltypes` AS `t0`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,6 @@ SELECT
`t0`.`f` < 50
)
THEN 2
ELSE CAST(NULL AS TINYINT)
ELSE NULL
END AS `Bucket(f, ())`
FROM `alltypes` AS `t0`
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,6 @@ SELECT
`t0`.`f` <= 50
)
THEN 2
ELSE CAST(NULL AS TINYINT)
ELSE NULL
END AS `Bucket(f, ())`
FROM `alltypes` AS `t0`
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,6 @@ SELECT
THEN 3
WHEN 50 <= `t0`.`f`
THEN 4
ELSE CAST(NULL AS TINYINT)
ELSE NULL
END AS `Bucket(f, ())`
FROM `alltypes` AS `t0`
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,6 @@ SELECT
`t0`.`f` <= 50
)
THEN 2
ELSE CAST(NULL AS TINYINT)
ELSE NULL
END AS `Bucket(f, ())`
FROM `alltypes` AS `t0`
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,6 @@ SELECT
`t0`.`f` <= 50
)
THEN 3
ELSE CAST(NULL AS TINYINT)
ELSE NULL
END AS `Bucket(f, ())`
FROM `alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -1,9 +1,3 @@
SELECT
CASE
WHEN `t0`.`f` <= 10
THEN 0
WHEN 10 < `t0`.`f`
THEN 1
ELSE CAST(NULL AS TINYINT)
END AS `Bucket(f, ())`
CASE WHEN `t0`.`f` <= 10 THEN 0 WHEN 10 < `t0`.`f` THEN 1 ELSE NULL END AS `Bucket(f, ())`
FROM `alltypes` AS `t0`
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,6 @@ SELECT
`t0`.`f` <= 50
)
THEN 2
ELSE CAST(NULL AS TINYINT)
ELSE NULL
END AS `Bucket(f, ())`
FROM `alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -1,9 +1,3 @@
SELECT
CASE
WHEN `t0`.`f` < 10
THEN 0
WHEN 10 <= `t0`.`f`
THEN 1
ELSE CAST(NULL AS TINYINT)
END AS `Bucket(f, ())`
CASE WHEN `t0`.`f` < 10 THEN 0 WHEN 10 <= `t0`.`f` THEN 1 ELSE NULL END AS `Bucket(f, ())`
FROM `alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -1,9 +1,3 @@
SELECT
CAST(CASE
WHEN `t0`.`f` < 10
THEN 0
WHEN 10 <= `t0`.`f`
THEN 1
ELSE CAST(NULL AS TINYINT)
END AS INT) AS `Cast(Bucket(f, ()), int32)`
CAST(CASE WHEN `t0`.`f` < 10 THEN 0 WHEN 10 <= `t0`.`f` THEN 1 ELSE NULL END AS INT) AS `Cast(Bucket(f, ()), int32)`
FROM `alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -1,9 +1,3 @@
SELECT
CAST(CASE
WHEN `t0`.`f` < 10
THEN 0
WHEN 10 <= `t0`.`f`
THEN 1
ELSE CAST(NULL AS TINYINT)
END AS DOUBLE) AS `Cast(Bucket(f, ()), float64)`
CAST(CASE WHEN `t0`.`f` < 10 THEN 0 WHEN 10 <= `t0`.`f` THEN 1 ELSE NULL END AS DOUBLE) AS `Cast(Bucket(f, ()), float64)`
FROM `alltypes` AS `t0`
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,6 @@ SELECT
`t0`.`f` <= 50
)
THEN 3
ELSE CAST(NULL AS TINYINT)
ELSE NULL
END AS `Bucket(f, ())`
FROM `alltypes` AS `t0`
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,6 @@ SELECT
THEN 3
WHEN 50 < `t0`.`f`
THEN 4
ELSE CAST(NULL AS TINYINT)
ELSE NULL
END AS `Bucket(f, ())`
FROM `alltypes` AS `t0`
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@ SELECT
THEN `t0`.`d` * 2
WHEN `t0`.`c` < 0
THEN `t0`.`a` * 2
ELSE CAST(NULL AS BIGINT)
END AS `SearchedCase((Greater(f, 0), Less(c, 0)), (Multiply(d, 2), Multiply(a, 2)), Cast(None, int64))`
ELSE NULL
END AS `SearchedCase((Greater(f, 0), Less(c, 0)), (Multiply(d, 2), Multiply(a, 2)), None)`
FROM `alltypes` AS `t0`
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
ibis.case()
.when(alltypes.g == lit, lit2)
.when(alltypes.g == lit1, alltypes.g)
.else_(ibis.literal(None).cast("string"))
.else_(ibis.literal(None))
.end()
.name("col2"),
alltypes.a,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ SELECT
THEN 'bar'
WHEN "t0"."g" = 'baz'
THEN "t0"."g"
ELSE CAST(NULL AS TEXT)
ELSE NULL
END AS "col2",
"t0"."a",
"t0"."b",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@ SELECT
THEN "t0"."d" * CAST(2 AS TINYINT)
WHEN "t0"."c" < CAST(0 AS TINYINT)
THEN "t0"."a" * CAST(2 AS TINYINT)
ELSE CAST(NULL AS BIGINT)
ELSE NULL
END AS "tmp"
FROM "alltypes" AS "t0"
6 changes: 0 additions & 6 deletions ibis/expr/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,6 @@ def when(self, case_expr: Any, result_expr: Any) -> Self:
case_expr = ibis.literal(case_expr)
if not isinstance(result_expr, ir.Value):
result_expr = ibis.literal(result_expr)

if not rlz.comparable(self.base, case_expr.op()):
raise TypeError(
f"Base expression {rlz._arg_type_error_format(self.base)} and "
f"case {rlz._arg_type_error_format(case_expr)} are not comparable"
)
return self.copy(
cases=self.cases + (case_expr,), results=self.results + (result_expr,)
)
Expand Down
27 changes: 17 additions & 10 deletions ibis/expr/operations/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,15 +289,24 @@ class SimpleCase(Value):
"""Simple case statement."""

base: Value
cases: VarTuple[Value]
results: VarTuple[Value]
cases: Annotated[VarTuple[Value], Length(at_least=1)]
results: Annotated[VarTuple[Value], Length(at_least=1)]
default: Value

shape = rlz.shape_like("base")

def __init__(self, cases, results, **kwargs):
def __init__(self, base, cases, results, default):
assert len(cases) == len(results)
super().__init__(cases=cases, results=results, **kwargs)
for case in cases:
if not rlz.comparable(base, case):
raise TypeError(
f"Base expression {rlz.arg_type_error_format(base)} and "
f"case {rlz.arg_type_error_format(case)} are not comparable"
)
super().__init__(base=base, cases=cases, results=results, default=default)

@attribute
def shape(self):
exprs = [self.base, *self.cases, *self.results, self.default]
return rlz.highest_precedence_shape(exprs)

@attribute
def dtype(self):
Expand All @@ -309,14 +318,12 @@ def dtype(self):
class SearchedCase(Value):
"""Searched case statement."""

cases: VarTuple[Value[dt.Boolean]]
results: VarTuple[Value]
cases: Annotated[VarTuple[Value[dt.Boolean]], Length(at_least=1)]
results: Annotated[VarTuple[Value], Length(at_least=1)]
default: Value

def __init__(self, cases, results, default):
assert len(cases) == len(results)
if default.dtype.is_null():
default = Cast(default, rlz.highest_precedence_dtype(results))
super().__init__(cases=cases, results=results, default=default)

@attribute
Expand Down
12 changes: 6 additions & 6 deletions ibis/expr/operations/logical.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ def __init__(self, left, right):
"""
if not rlz.comparable(left, right):
raise IbisTypeError(
f"Arguments {rlz._arg_type_error_format(left)} and "
f"{rlz._arg_type_error_format(right)} are not comparable"
f"Arguments {rlz.arg_type_error_format(left)} and "
f"{rlz.arg_type_error_format(right)} are not comparable"
)
super().__init__(left=left, right=right)

Expand Down Expand Up @@ -121,13 +121,13 @@ class Between(Value):
def __init__(self, arg, lower_bound, upper_bound):
if not rlz.comparable(arg, lower_bound):
raise ValidationError(
f"Arguments {rlz._arg_type_error_format(arg)} and "
f"{rlz._arg_type_error_format(lower_bound)} are not comparable"
f"Arguments {rlz.arg_type_error_format(arg)} and "
f"{rlz.arg_type_error_format(lower_bound)} are not comparable"
)
if not rlz.comparable(arg, upper_bound):
raise ValidationError(
f"Arguments {rlz._arg_type_error_format(arg)} and "
f"{rlz._arg_type_error_format(upper_bound)} are not comparable"
f"Arguments {rlz.arg_type_error_format(arg)} and "
f"{rlz.arg_type_error_format(upper_bound)} are not comparable"
)
super().__init__(arg=arg, lower_bound=lower_bound, upper_bound=upper_bound)

Expand Down
2 changes: 1 addition & 1 deletion ibis/expr/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def _promote_interval_resolution(units: list[IntervalUnit]) -> IntervalUnit:
raise AssertionError("unreachable")


def _arg_type_error_format(op):
def arg_type_error_format(op: ops.Value) -> str:
if isinstance(op, ops.Literal):
return f"Literal({op.value}):{op.dtype}"
else:
Expand Down
70 changes: 68 additions & 2 deletions ibis/tests/expr/test_case.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from __future__ import annotations

import pytest

import ibis
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
import ibis.expr.types as ir
from ibis import _
from ibis.common.annotations import SignatureValidationError
from ibis.tests.util import assert_equal, assert_pickle_roundtrip


Expand Down Expand Up @@ -44,6 +47,41 @@ def test_ifelse_function_deferred(table):
assert res.equals(sol)


def test_case_dshape(table):
assert isinstance(ibis.case().when(True, "bar").when(False, "bar").end(), ir.Scalar)
assert isinstance(ibis.case().when(True, None).else_("bar").end(), ir.Scalar)
assert isinstance(
ibis.case().when(table.b == 9, None).else_("bar").end(), ir.Column
)
assert isinstance(ibis.case().when(True, table.a).else_(42).end(), ir.Column)
assert isinstance(ibis.case().when(True, 42).else_(table.a).end(), ir.Column)
assert isinstance(ibis.case().when(True, table.a).else_(table.b).end(), ir.Column)

assert isinstance(ibis.literal(5).case().when(9, 42).end(), ir.Scalar)
assert isinstance(ibis.literal(5).case().when(9, 42).else_(43).end(), ir.Scalar)
assert isinstance(ibis.literal(5).case().when(table.a, 42).end(), ir.Column)
assert isinstance(ibis.literal(5).case().when(9, table.a).end(), ir.Column)
assert isinstance(ibis.literal(5).case().when(table.a, table.b).end(), ir.Column)
assert isinstance(
ibis.literal(5).case().when(9, 42).else_(table.a).end(), ir.Column
)
assert isinstance(table.a.case().when(9, 42).end(), ir.Column)
assert isinstance(table.a.case().when(table.b, 42).end(), ir.Column)
assert isinstance(table.a.case().when(9, table.b).end(), ir.Column)
assert isinstance(table.a.case().when(table.a, table.b).end(), ir.Column)


def test_case_dtype():
assert isinstance(
ibis.case().when(True, "bar").when(False, "bar").end(), ir.StringValue
)
assert isinstance(ibis.case().when(True, None).else_("bar").end(), ir.StringValue)
with pytest.raises(TypeError):
assert ibis.case().when(True, 5).when(False, "bar").end()
with pytest.raises(TypeError):
assert ibis.case().when(True, 5).else_("bar").end()


def test_simple_case_expr(table):
case1, result1 = "foo", table.a
case2, result2 = "bar", table.c
Expand Down Expand Up @@ -162,8 +200,6 @@ def test_multiple_case_null_else(table):
op = expr.op()
assert isinstance(expr, ir.StringColumn)
assert isinstance(op.default.to_expr(), ir.Value)
assert isinstance(op.default, ops.Cast)
assert op.default.to == dt.string


def test_case_mixed_type():
Expand All @@ -177,3 +213,33 @@ def test_case_mixed_type():
)
result = t0[expr]
assert result["label"].type().equals(dt.string)


def test_err_on_nonbool(table):
with pytest.raises(SignatureValidationError):
ibis.case().when(table.a, "bar").else_("baz").end()


@pytest.mark.xfail(reason="Literal('foo', type=bool), should error, but doesn't")
def test_err_on_nonbool2():
with pytest.raises(SignatureValidationError):
ibis.case().when("foo", "bar").else_("baz").end()


def test_err_on_noncomparable(table):
table.a.case().when(8, "bar").end()
table.a.case().when(-8, "bar").end()
# Can't compare an int to a string
with pytest.raises(TypeError):
table.a.case().when("foo", "bar").end()


def test_err_on_empty_cases(table):
with pytest.raises(SignatureValidationError):
ibis.case().end()
with pytest.raises(SignatureValidationError):
ibis.case().else_(42).end()
with pytest.raises(SignatureValidationError):
table.a.case().end()
with pytest.raises(SignatureValidationError):
table.a.case().else_(42).end()

0 comments on commit aee467b

Please sign in to comment.