Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(case): fix dshape, error on noncomparable and empty cases #9559

Merged
merged 2 commits into from
Jul 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,6 @@ SELECT
`t0`.`value` <= 3
)
THEN 1
ELSE CAST(NULL AS INT64)
ELSE NULL
END AS `tmp`
FROM `t` AS `t0`
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ FROM (
`t0`.`f` <= 50
)
THEN 3
ELSE CAST(NULL AS TINYINT)
ELSE NULL
END AS `tier`,
COUNT(*) AS `CountStar(alltypes)`
FROM `alltypes` AS `t0`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,6 @@ SELECT
`t0`.`f` < 50
)
THEN 2
ELSE CAST(NULL AS TINYINT)
ELSE NULL
END AS `Bucket(f, ())`
FROM `alltypes` AS `t0`
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,6 @@ SELECT
`t0`.`f` <= 50
)
THEN 2
ELSE CAST(NULL AS TINYINT)
ELSE NULL
END AS `Bucket(f, ())`
FROM `alltypes` AS `t0`
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,6 @@ SELECT
THEN 3
WHEN 50 <= `t0`.`f`
THEN 4
ELSE CAST(NULL AS TINYINT)
ELSE NULL
END AS `Bucket(f, ())`
FROM `alltypes` AS `t0`
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,6 @@ SELECT
`t0`.`f` <= 50
)
THEN 2
ELSE CAST(NULL AS TINYINT)
ELSE NULL
END AS `Bucket(f, ())`
FROM `alltypes` AS `t0`
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,6 @@ SELECT
`t0`.`f` <= 50
)
THEN 3
ELSE CAST(NULL AS TINYINT)
ELSE NULL
END AS `Bucket(f, ())`
FROM `alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -1,9 +1,3 @@
SELECT
CASE
WHEN `t0`.`f` <= 10
THEN 0
WHEN 10 < `t0`.`f`
THEN 1
ELSE CAST(NULL AS TINYINT)
END AS `Bucket(f, ())`
CASE WHEN `t0`.`f` <= 10 THEN 0 WHEN 10 < `t0`.`f` THEN 1 ELSE NULL END AS `Bucket(f, ())`
FROM `alltypes` AS `t0`
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,6 @@ SELECT
`t0`.`f` <= 50
)
THEN 2
ELSE CAST(NULL AS TINYINT)
ELSE NULL
END AS `Bucket(f, ())`
FROM `alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -1,9 +1,3 @@
SELECT
CASE
WHEN `t0`.`f` < 10
THEN 0
WHEN 10 <= `t0`.`f`
THEN 1
ELSE CAST(NULL AS TINYINT)
END AS `Bucket(f, ())`
CASE WHEN `t0`.`f` < 10 THEN 0 WHEN 10 <= `t0`.`f` THEN 1 ELSE NULL END AS `Bucket(f, ())`
FROM `alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -1,9 +1,3 @@
SELECT
CAST(CASE
WHEN `t0`.`f` < 10
THEN 0
WHEN 10 <= `t0`.`f`
THEN 1
ELSE CAST(NULL AS TINYINT)
END AS INT) AS `Cast(Bucket(f, ()), int32)`
CAST(CASE WHEN `t0`.`f` < 10 THEN 0 WHEN 10 <= `t0`.`f` THEN 1 ELSE NULL END AS INT) AS `Cast(Bucket(f, ()), int32)`
FROM `alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -1,9 +1,3 @@
SELECT
CAST(CASE
WHEN `t0`.`f` < 10
THEN 0
WHEN 10 <= `t0`.`f`
THEN 1
ELSE CAST(NULL AS TINYINT)
END AS DOUBLE) AS `Cast(Bucket(f, ()), float64)`
CAST(CASE WHEN `t0`.`f` < 10 THEN 0 WHEN 10 <= `t0`.`f` THEN 1 ELSE NULL END AS DOUBLE) AS `Cast(Bucket(f, ()), float64)`
FROM `alltypes` AS `t0`
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,6 @@ SELECT
`t0`.`f` <= 50
)
THEN 3
ELSE CAST(NULL AS TINYINT)
ELSE NULL
END AS `Bucket(f, ())`
FROM `alltypes` AS `t0`
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,6 @@ SELECT
THEN 3
WHEN 50 < `t0`.`f`
THEN 4
ELSE CAST(NULL AS TINYINT)
ELSE NULL
END AS `Bucket(f, ())`
FROM `alltypes` AS `t0`
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@ SELECT
THEN `t0`.`d` * 2
WHEN `t0`.`c` < 0
THEN `t0`.`a` * 2
ELSE CAST(NULL AS BIGINT)
END AS `SearchedCase((Greater(f, 0), Less(c, 0)), (Multiply(d, 2), Multiply(a, 2)), Cast(None, int64))`
ELSE NULL
END AS `SearchedCase((Greater(f, 0), Less(c, 0)), (Multiply(d, 2), Multiply(a, 2)), None)`
FROM `alltypes` AS `t0`
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
ibis.case()
.when(alltypes.g == lit, lit2)
.when(alltypes.g == lit1, alltypes.g)
.else_(ibis.literal(None).cast("string"))
.else_(ibis.literal(None))
.end()
.name("col2"),
alltypes.a,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ SELECT
THEN 'bar'
WHEN "t0"."g" = 'baz'
THEN "t0"."g"
ELSE CAST(NULL AS TEXT)
ELSE NULL
END AS "col2",
"t0"."a",
"t0"."b",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@ SELECT
THEN "t0"."d" * 2
WHEN "t0"."c" < 0
THEN "t0"."a" * 2
ELSE CAST(NULL AS BIGINT)
ELSE NULL
END AS "tmp"
FROM "alltypes" AS "t0"
6 changes: 0 additions & 6 deletions ibis/expr/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,6 @@ def when(self, case_expr: Any, result_expr: Any) -> Self:
case_expr = ibis.literal(case_expr)
if not isinstance(result_expr, ir.Value):
result_expr = ibis.literal(result_expr)

if not rlz.comparable(self.base, case_expr.op()):
raise TypeError(
f"Base expression {rlz._arg_type_error_format(self.base)} and "
f"case {rlz._arg_type_error_format(case_expr)} are not comparable"
)
return self.copy(
cases=self.cases + (case_expr,), results=self.results + (result_expr,)
)
Expand Down
27 changes: 17 additions & 10 deletions ibis/expr/operations/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,15 +289,24 @@ class SimpleCase(Value):
"""Simple case statement."""

base: Value
cases: VarTuple[Value]
results: VarTuple[Value]
cases: Annotated[VarTuple[Value], Length(at_least=1)]
results: Annotated[VarTuple[Value], Length(at_least=1)]
default: Value

shape = rlz.shape_like("base")

def __init__(self, cases, results, **kwargs):
def __init__(self, base, cases, results, default):
assert len(cases) == len(results)
super().__init__(cases=cases, results=results, **kwargs)
for case in cases:
if not rlz.comparable(base, case):
raise TypeError(
f"Base expression {rlz.arg_type_error_format(base)} and "
f"case {rlz.arg_type_error_format(case)} are not comparable"
)
super().__init__(base=base, cases=cases, results=results, default=default)

@attribute
def shape(self):
exprs = [self.base, *self.cases, *self.results, self.default]
return rlz.highest_precedence_shape(exprs)

@attribute
def dtype(self):
Expand All @@ -309,14 +318,12 @@ def dtype(self):
class SearchedCase(Value):
"""Searched case statement."""

cases: VarTuple[Value[dt.Boolean]]
results: VarTuple[Value]
cases: Annotated[VarTuple[Value[dt.Boolean]], Length(at_least=1)]
results: Annotated[VarTuple[Value], Length(at_least=1)]
default: Value

def __init__(self, cases, results, default):
assert len(cases) == len(results)
if default.dtype.is_null():
default = Cast(default, rlz.highest_precedence_dtype(results))
super().__init__(cases=cases, results=results, default=default)

@attribute
Expand Down
12 changes: 6 additions & 6 deletions ibis/expr/operations/logical.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ def __init__(self, left, right):
"""
if not rlz.comparable(left, right):
raise IbisTypeError(
f"Arguments {rlz._arg_type_error_format(left)} and "
f"{rlz._arg_type_error_format(right)} are not comparable"
f"Arguments {rlz.arg_type_error_format(left)} and "
f"{rlz.arg_type_error_format(right)} are not comparable"
)
super().__init__(left=left, right=right)

Expand Down Expand Up @@ -121,13 +121,13 @@ class Between(Value):
def __init__(self, arg, lower_bound, upper_bound):
if not rlz.comparable(arg, lower_bound):
raise ValidationError(
f"Arguments {rlz._arg_type_error_format(arg)} and "
f"{rlz._arg_type_error_format(lower_bound)} are not comparable"
f"Arguments {rlz.arg_type_error_format(arg)} and "
f"{rlz.arg_type_error_format(lower_bound)} are not comparable"
)
if not rlz.comparable(arg, upper_bound):
raise ValidationError(
f"Arguments {rlz._arg_type_error_format(arg)} and "
f"{rlz._arg_type_error_format(upper_bound)} are not comparable"
f"Arguments {rlz.arg_type_error_format(arg)} and "
f"{rlz.arg_type_error_format(upper_bound)} are not comparable"
)
super().__init__(arg=arg, lower_bound=lower_bound, upper_bound=upper_bound)

Expand Down
2 changes: 1 addition & 1 deletion ibis/expr/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def _promote_interval_resolution(units: list[IntervalUnit]) -> IntervalUnit:
raise AssertionError("unreachable")


def _arg_type_error_format(op):
def arg_type_error_format(op: ops.Value) -> str:
if isinstance(op, ops.Literal):
return f"Literal({op.value}):{op.dtype}"
else:
Expand Down
64 changes: 62 additions & 2 deletions ibis/tests/expr/test_case.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from __future__ import annotations

import pytest

import ibis
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
import ibis.expr.types as ir
from ibis import _
from ibis.common.annotations import SignatureValidationError
from ibis.tests.util import assert_equal, assert_pickle_roundtrip


Expand Down Expand Up @@ -44,6 +47,41 @@ def test_ifelse_function_deferred(table):
assert res.equals(sol)


def test_case_dshape(table):
    """Check the scalar/column shape of searched and simple case expressions.

    Expressions built only from literals are expected to be scalars; as soon
    as any case, result, default or base is a table column, the result is
    expected to be a column.
    """
    # --- Searched case: ibis.case() ---
    only_literals = ibis.case().when(True, "bar").when(False, "bar").end()
    assert isinstance(only_literals, ir.Scalar)

    literal_with_null_result = ibis.case().when(True, None).else_("bar").end()
    assert isinstance(literal_with_null_result, ir.Scalar)

    column_condition = ibis.case().when(table.b == 9, None).else_("bar").end()
    assert isinstance(column_condition, ir.Column)

    column_result = ibis.case().when(True, table.a).else_(42).end()
    assert isinstance(column_result, ir.Column)

    column_default = ibis.case().when(True, 42).else_(table.a).end()
    assert isinstance(column_default, ir.Column)

    column_result_and_default = ibis.case().when(True, table.a).else_(table.b).end()
    assert isinstance(column_result_and_default, ir.Column)

    # --- Simple case on a literal base ---
    literal_base = ibis.literal(5).case().when(9, 42).end()
    assert isinstance(literal_base, ir.Scalar)

    literal_base_with_default = ibis.literal(5).case().when(9, 42).else_(43).end()
    assert isinstance(literal_base_with_default, ir.Scalar)

    literal_base_column_case = ibis.literal(5).case().when(table.a, 42).end()
    assert isinstance(literal_base_column_case, ir.Column)

    literal_base_column_result = ibis.literal(5).case().when(9, table.a).end()
    assert isinstance(literal_base_column_result, ir.Column)

    literal_base_column_both = ibis.literal(5).case().when(table.a, table.b).end()
    assert isinstance(literal_base_column_both, ir.Column)

    literal_base_column_default = (
        ibis.literal(5).case().when(9, 42).else_(table.a).end()
    )
    assert isinstance(literal_base_column_default, ir.Column)

    # --- Simple case on a column base ---
    column_base = table.a.case().when(9, 42).end()
    assert isinstance(column_base, ir.Column)

    column_base_column_case = table.a.case().when(table.b, 42).end()
    assert isinstance(column_base_column_case, ir.Column)

    column_base_column_result = table.a.case().when(9, table.b).end()
    assert isinstance(column_base_column_result, ir.Column)

    column_base_column_both = table.a.case().when(table.a, table.b).end()
    assert isinstance(column_base_column_both, ir.Column)


def test_case_dtype():
    """Check the result dtype of searched case expressions.

    String results (optionally mixed with a null result) are expected to give
    a string value; mixing integer and string results is expected to raise
    ``TypeError``.
    """
    all_strings = ibis.case().when(True, "bar").when(False, "bar").end()
    assert isinstance(all_strings, ir.StringValue)

    null_and_string = ibis.case().when(True, None).else_("bar").end()
    assert isinstance(null_and_string, ir.StringValue)

    # Incompatible result types across two ``when`` branches.
    with pytest.raises(TypeError):
        ibis.case().when(True, 5).when(False, "bar").end()

    # Incompatible result types between a ``when`` branch and the default.
    with pytest.raises(TypeError):
        ibis.case().when(True, 5).else_("bar").end()


def test_simple_case_expr(table):
case1, result1 = "foo", table.a
case2, result2 = "bar", table.c
Expand Down Expand Up @@ -162,8 +200,6 @@ def test_multiple_case_null_else(table):
op = expr.op()
assert isinstance(expr, ir.StringColumn)
assert isinstance(op.default.to_expr(), ir.Value)
assert isinstance(op.default, ops.Cast)
cpcloud marked this conversation as resolved.
Show resolved Hide resolved
assert op.default.to == dt.string


def test_case_mixed_type():
Expand All @@ -177,3 +213,27 @@ def test_case_mixed_type():
)
result = t0[expr]
assert result["label"].type().equals(dt.string)


def test_err_on_nonbool_expr(table):
    """Searched cases with a non-boolean condition must fail validation."""
    # Each builder is deferred so the whole expression is constructed inside
    # the ``pytest.raises`` context, wherever in the chain the error is raised.
    builders = (
        lambda: ibis.case().when(table.a, "bar").else_("baz").end(),
        lambda: ibis.case().when(ibis.literal(1), "bar").else_("baz").end(),
    )
    for build in builders:
        with pytest.raises(SignatureValidationError):
            build()


def test_err_on_noncomparable(table):
    """A simple case whose base and case values are not comparable must raise."""
    # Can't compare an int to a string
    expected_error = pytest.raises(TypeError)
    with expected_error:
        table.a.case().when("foo", "bar").end()


def test_err_on_empty_cases(table):
    """Ending a case builder without any ``when`` branch must fail validation."""
    # Covers both the searched (``ibis.case()``) and simple
    # (``table.a.case()``) builders, with and without a default. Each builder
    # is deferred so it runs entirely inside the ``pytest.raises`` context.
    builders = (
        lambda: ibis.case().end(),
        lambda: ibis.case().else_(42).end(),
        lambda: table.a.case().end(),
        lambda: table.a.case().else_(42).end(),
    )
    for build in builders:
        with pytest.raises(SignatureValidationError):
            build()