From 666d98bc9b6a047b79cedd6d9ec38b1130217645 Mon Sep 17 00:00:00 2001 From: Nick Crews Date: Wed, 24 Jul 2024 11:53:01 -0800 Subject: [PATCH] feat: move from .case() to .cases() Fixes https://github.com/ibis-project/ibis/issues/7280 --- docs/posts/ci-analysis/index.qmd | 14 +- docs/tutorials/ibis-for-sql-users.qmd | 24 +- .../clickhouse/tests/test_operators.py | 14 +- ibis/backends/impala/tests/test_case_exprs.py | 4 +- ibis/backends/pandas/tests/test_operations.py | 8 +- ibis/backends/snowflake/tests/test_udf.py | 38 +-- ibis/backends/tests/sql/conftest.py | 4 +- .../test_case_in_projection/decompiled.py | 20 +- ibis/backends/tests/sql/test_select_sql.py | 4 +- ibis/backends/tests/test_aggregation.py | 10 +- ibis/backends/tests/test_conditionals.py | 100 +++++-- ibis/backends/tests/test_generic.py | 15 +- ibis/backends/tests/test_sql.py | 20 +- ibis/backends/tests/test_string.py | 18 +- ibis/backends/tests/test_struct.py | 2 +- ibis/backends/tests/tpc/h/test_queries.py | 24 +- ibis/expr/api.py | 88 +++--- ibis/expr/decompile.py | 16 +- ibis/expr/operations/logical.py | 2 +- ibis/expr/types/generic.py | 260 +++++++++--------- ibis/expr/types/numeric.py | 9 +- ibis/expr/types/relations.py | 4 +- ibis/tests/expr/test_case.py | 140 ++++------ ibis/tests/expr/test_value_exprs.py | 16 +- 24 files changed, 413 insertions(+), 441 deletions(-) diff --git a/docs/posts/ci-analysis/index.qmd b/docs/posts/ci-analysis/index.qmd index 5babc2c6d0c61..65159d5bbe9c7 100644 --- a/docs/posts/ci-analysis/index.qmd +++ b/docs/posts/ci-analysis/index.qmd @@ -203,14 +203,12 @@ Let's also give them some names that'll look nice on our plots. stats = stats.mutate( raw_improvements=_.has_poetry.cast("int") + _.has_team.cast("int") ).mutate( - improvements=( - _.raw_improvements.case() - .when(0, "None") - .when(1, "Poetry") - .when(2, "Poetry + Team Plan") - .else_("NA") - .end() - ), + improvements=_.raw_improvements.cases( + (0, "None"), + (1, "Poetry"), + (2, "Poetry + Team Plan"), + else_="NA", + ) team_plan=ibis.where(_.raw_improvements > 1, "Poetry + Team Plan", "None"), ) stats diff --git a/docs/tutorials/ibis-for-sql-users.qmd b/docs/tutorials/ibis-for-sql-users.qmd index 534090bfce649..6d8c9a556b407 100644 --- a/docs/tutorials/ibis-for-sql-users.qmd +++ b/docs/tutorials/ibis-for-sql-users.qmd @@ -473,11 +473,11 @@ semantics: case = ( t.one.cast("timestamp") .year() - .case() - .when(2015, "This year") - .when(2014, "Last year") - .else_("Earlier") - .end() + .cases( + (2015, "This year"), + (2014, "Last year"), + else_="Earlier", + ) ) expr = t.mutate(year_group=case) @@ -496,18 +496,16 @@ CASE END ``` -To do this, use `ibis.case`: +To do this, use `ibis.cases`: ```{python} -case = ( - ibis.case() - .when(t.two < 0, t.three * 2) - .when(t.two > 1, t.three) - .else_(t.two) - .end() +cases = ibis.cases( + (t.two < 0, t.three * 2), + (t.two > 1, t.three), + else_=t.two, ) -expr = t.mutate(cond_value=case) +expr = t.mutate(cond_value=cases) ibis.to_sql(expr) ``` diff --git a/ibis/backends/clickhouse/tests/test_operators.py b/ibis/backends/clickhouse/tests/test_operators.py index 4ca53a3d2b9f3..3ff07ce916a45 100644 --- a/ibis/backends/clickhouse/tests/test_operators.py +++ b/ibis/backends/clickhouse/tests/test_operators.py @@ -201,9 +201,7 @@ def test_ifelse(alltypes, df, op, pandas_op): def test_simple_case(con, alltypes, assert_sql): t = alltypes - expr = ( - t.string_col.case().when("foo", "bar").when("baz", "qux").else_("default").end() - ) + expr = t.string_col.cases(("foo", "bar"), ("baz", "qux"), else_="default") assert_sql(expr) assert len(con.execute(expr)) @@ -211,12 +209,10 @@ def test_simple_case(con, alltypes, assert_sql): def test_search_case(con, alltypes, assert_sql): t = alltypes - expr = ( - ibis.case() - .when(t.float_col > 0, t.int_col * 2) - .when(t.float_col < 0, t.int_col) - .else_(0) - .end() + expr = ibis.cases( + (t.float_col > 0, t.int_col * 2), + (t.float_col < 0, t.int_col), + else_=0, ) assert_sql(expr) diff --git a/ibis/backends/impala/tests/test_case_exprs.py b/ibis/backends/impala/tests/test_case_exprs.py index a195928b12214..360fbf9522c8b 100644 --- a/ibis/backends/impala/tests/test_case_exprs.py +++ b/ibis/backends/impala/tests/test_case_exprs.py @@ -14,13 +14,13 @@ def table(mockcon): @pytest.fixture def simple_case(table): - return table.g.case().when("foo", "bar").when("baz", "qux").else_("default").end() + return table.g.cases(("foo", "bar"), ("baz", "qux"), else_="default") @pytest.fixture def search_case(table): t = table - return ibis.case().when(t.f > 0, t.d * 2).when(t.c < 0, t.a * 2).end() + return ibis.cases((t.f > 0, t.d * 2), (t.c < 0, t.a * 2)) @pytest.fixture diff --git a/ibis/backends/pandas/tests/test_operations.py b/ibis/backends/pandas/tests/test_operations.py index b116995c22bdf..8fa1338132083 100644 --- a/ibis/backends/pandas/tests/test_operations.py +++ b/ibis/backends/pandas/tests/test_operations.py @@ -685,13 +685,7 @@ def test_summary_non_numeric(batting, batting_df): def test_non_range_index(): def do_replace(col): - return col.cases( - ( - (1, "one"), - (2, "two"), - ), - default="unk", - ) + return col.cases((1, "one"), (2, "two"), else_="unk") df = pd.DataFrame( { diff --git a/ibis/backends/snowflake/tests/test_udf.py b/ibis/backends/snowflake/tests/test_udf.py index 4a59013cebece..2ee68897f4419 100644 --- a/ibis/backends/snowflake/tests/test_udf.py +++ b/ibis/backends/snowflake/tests/test_udf.py @@ -8,7 +8,6 @@ import pytest from pytest import param -import ibis import ibis.expr.datatypes as dt from ibis import udf @@ -122,36 +121,23 @@ def predict_price( df.columns = ["CARAT_SCALED", "CUT_ENCODED", "COLOR_ENCODED", "CLARITY_ENCODED"] return model.predict(df) - def cases(value, mapping): - """This should really be a top-level function or method.""" - expr = ibis.case() - for k, v in mapping.items(): - expr = expr.when(value == k, v) - return expr.end() - diamonds = con.tables.DIAMONDS expr = diamonds.mutate( predicted_price=predict_price( (_.carat - _.carat.mean()) / _.carat.std(), - cases( - _.cut, - { - c: i - for i, c in enumerate( - ("Fair", "Good", "Very Good", "Premium", "Ideal"), start=1 - ) - }, + _.cut.cases( + (c, i) + for i, c in enumerate( + ("Fair", "Good", "Very Good", "Premium", "Ideal"), start=1 + ) ), - cases(_.color, {c: i for i, c in enumerate("DEFGHIJ", start=1)}), - cases( - _.clarity, - { - c: i - for i, c in enumerate( - ("I1", "IF", "SI1", "SI2", "VS1", "VS2", "VVS1", "VVS2"), - start=1, - ) - }, + _.color.cases((c, i) for i, c in enumerate("DEFGHIJ", start=1)), + _.clarity.cases( + (c, i) + for i, c in enumerate( + ("I1", "IF", "SI1", "SI2", "VS1", "VS2", "VVS1", "VVS2"), + start=1, + ) ), ) ) diff --git a/ibis/backends/tests/sql/conftest.py b/ibis/backends/tests/sql/conftest.py index 04667e60e033b..06de1c83c8c08 100644 --- a/ibis/backends/tests/sql/conftest.py +++ b/ibis/backends/tests/sql/conftest.py @@ -164,13 +164,13 @@ def difference(con): @pytest.fixture(scope="module") def simple_case(con): t = con.table("alltypes") - return t.g.case().when("foo", "bar").when("baz", "qux").else_("default").end() + return t.g.cases(("foo", "bar"), ("baz", "qux"), else_="default") @pytest.fixture(scope="module") def search_case(con): t = con.table("alltypes") - return ibis.case().when(t.f > 0, t.d * 2).when(t.c < 0, t.a * 2).end() + return ibis.cases((t.f > 0, t.d * 2), (t.c < 0, t.a * 2)) @pytest.fixture(scope="module") diff --git a/ibis/backends/tests/sql/snapshots/test_select_sql/test_case_in_projection/decompiled.py b/ibis/backends/tests/sql/snapshots/test_select_sql/test_case_in_projection/decompiled.py index 0deb771897042..35fb932c2248f 100644 --- a/ibis/backends/tests/sql/snapshots/test_select_sql/test_case_in_projection/decompiled.py +++ b/ibis/backends/tests/sql/snapshots/test_select_sql/test_case_in_projection/decompiled.py @@ -22,18 +22,14 @@ lit2 = ibis.literal("bar") result = alltypes.select( - alltypes.g.case() - .when(lit, lit2) - .when(lit1, ibis.literal("qux")) - .else_(ibis.literal("default")) - .end() - .name("col1"), - ibis.case() - .when(alltypes.g == lit, lit2) - .when(alltypes.g == lit1, alltypes.g) - .else_(ibis.literal(None)) - .end() - .name("col2"), + alltypes.g.cases( + (lit, lit2), (lit1, ibis.literal("qux")), else_=ibis.literal("default") + ).name("col1"), + ibis.cases( + (alltypes.g == lit, lit2), + (alltypes.g == lit1, alltypes.g), + else_=ibis.literal(None), + ).name("col2"), alltypes.a, alltypes.b, alltypes.c, diff --git a/ibis/backends/tests/sql/test_select_sql.py b/ibis/backends/tests/sql/test_select_sql.py index 94a52017f763f..24893739fb6eb 100644 --- a/ibis/backends/tests/sql/test_select_sql.py +++ b/ibis/backends/tests/sql/test_select_sql.py @@ -397,8 +397,8 @@ def test_bool_bool(snapshot): def test_case_in_projection(alltypes, snapshot): t = alltypes - expr = t.g.case().when("foo", "bar").when("baz", "qux").else_("default").end() - expr2 = ibis.case().when(t.g == "foo", "bar").when(t.g == "baz", t.g).end() + expr = t.g.cases(("foo", "bar"), ("baz", "qux"), else_=("default")) + expr2 = ibis.cases((t.g == "foo", "bar"), (t.g == "baz", t.g)) expr = t[expr.name("col1"), expr2.name("col2"), t] snapshot.assert_match(to_sql(expr), "out.sql") diff --git a/ibis/backends/tests/test_aggregation.py b/ibis/backends/tests/test_aggregation.py index b95eaf676bfe3..68c99c9012414 100644 --- a/ibis/backends/tests/test_aggregation.py +++ b/ibis/backends/tests/test_aggregation.py @@ -655,7 +655,7 @@ def test_first_last(backend, alltypes, method, filtered): # To sanely test this we create a column that is a mix of nulls and a # single value (or a single value after filtering is applied). if filtered: - new = alltypes.int_col.cases([(3, 30), (4, 40)]) + new = alltypes.int_col.cases((3, 30), (4, 40)) where = _.int_col == 3 else: new = (alltypes.int_col == 3).ifelse(30, None) @@ -687,7 +687,7 @@ def test_arbitrary(backend, alltypes, df, filtered): # _something_ we create a column that is a mix of nulls and a single value # (or a single value after filtering is applied). if filtered: - new = alltypes.int_col.cases([(3, 30), (4, 40)]) + new = alltypes.int_col.cases((3, 30), (4, 40)) where = _.int_col == 3 else: new = (alltypes.int_col == 3).ifelse(30, None) @@ -1433,9 +1433,7 @@ def collect_udf(v): def test_binds_are_cast(alltypes): expr = alltypes.aggregate( - high_line_count=( - alltypes.string_col.case().when("1-URGENT", 1).else_(0).end().sum() - ) + high_line_count=alltypes.string_col.cases(("1-URGENT", 1), else_=0).sum() ) expr.execute() @@ -1481,7 +1479,7 @@ def test_agg_name_in_output_column(alltypes): def test_grouped_case(backend, con): table = ibis.memtable({"key": [1, 1, 2, 2], "value": [10, 30, 20, 40]}) - case_expr = ibis.case().when(table.value < 25, table.value).else_(ibis.null()).end() + case_expr = ibis.cases((table.value < 25, table.value), else_=ibis.null()) expr = ( table.group_by(k="key") diff --git a/ibis/backends/tests/test_conditionals.py b/ibis/backends/tests/test_conditionals.py index 367dae384a5ae..660acec8100bf 100644 --- a/ibis/backends/tests/test_conditionals.py +++ b/ibis/backends/tests/test_conditionals.py @@ -62,17 +62,12 @@ def test_substitute(backend): "inp, exp", [ pytest.param( - lambda: ibis.literal(1) - .case() - .when(1, "one") - .when(2, "two") - .else_("other") - .end(), + lambda: ibis.literal(1).cases((1, "one"), (2, "two"), else_="other"), "one", id="one_kwarg", ), pytest.param( - lambda: ibis.literal(5).case().when(1, "one").when(2, "two").end(), + lambda: ibis.literal(5).cases((1, "one"), (2, "two")), None, id="fallthrough", ), @@ -93,13 +88,8 @@ def test_value_cases_scalar(con, inp, exp): ) def test_value_cases_column(batting): df = batting.to_pandas() - expr = ( - batting.RBI.case() - .when(5, "five") - .when(4, "four") - .when(3, "three") - .else_("could be good?") - .end() + expr = batting.RBI.cases( + (5, "five"), (4, "four"), (3, "three"), else_="could be good?" ) result = expr.execute() expected = np.select( @@ -112,7 +102,7 @@ def test_value_cases_column(batting): def test_ibis_cases_scalar(): - expr = ibis.literal(5).case().when(5, "five").when(4, "four").end() + expr = ibis.literal(5).cases((5, "five"), (4, "four")) result = expr.execute() assert result == "five" @@ -125,12 +115,8 @@ def test_ibis_cases_scalar(): def test_ibis_cases_column(batting): t = batting df = batting.to_pandas() - expr = ( - ibis.case() - .when(t.RBI < 5, "really bad team") - .when(t.teamID == "PH1", "ph1 team") - .else_(t.teamID) - .end() + expr = ibis.cases( + (t.RBI < 5, "really bad team"), (t.teamID == "PH1", "ph1 team"), else_=t.teamID ) result = expr.execute() expected = np.select( @@ -145,5 +131,75 @@ def test_ibis_cases_column(batting): @pytest.mark.broken("clickhouse", reason="special case this and returns 'oops'") def test_value_cases_null(con): """CASE x WHEN NULL never gets hit""" - e = ibis.literal(5).nullif(5).case().when(None, "oops").else_("expected").end() + e = ibis.literal(5).nullif(5).cases((None, "oops"), else_="expected") assert con.execute(e) == "expected" + + +@pytest.mark.broken("pyspark", reason="raises a ResourceWarning that we can't catch") +def test_ibis_case_is_deprecated(con): + # just to make sure that the deprecated .case() method still works + with pytest.warns(FutureWarning, match=".cases"): + assert con.execute(ibis.case().when(True, "yes").end()) == "yes" + with pytest.warns(FutureWarning, match=".cases"): + assert pd.isna(con.execute(ibis.case().when(False, "yes").end())) + with pytest.warns(FutureWarning, match=".cases"): + assert con.execute(ibis.case().when(False, "yes").else_("no").end()) == "no" + + with pytest.warns(FutureWarning, match=".cases"): + assert con.execute(ibis.literal("a").case().when("a", "yes").end()) == "yes" + with pytest.warns(FutureWarning, match=".cases"): + assert pd.isna(con.execute(ibis.literal("a").case().when("b", "yes").end())) + with pytest.warns(FutureWarning, match=".cases"): + assert ( + con.execute(ibis.literal("a").case().when("b", "yes").else_("no").end()) + == "no" + ) + + +@pytest.mark.parametrize( + "inp, exp", + [ + pytest.param( + lambda: ibis.literal(1).cases([(1, "one"), (2, "two")], "other"), + "one", + id="basic", + ), + pytest.param( + lambda: ibis.literal(1).cases([(1, "one"), (2, "two")], default="other"), + "one", + id="one_kwarg", + ), + pytest.param( + lambda: ibis.literal(1).cases( + case_result_pairs=[(1, "one"), (2, "two")], default="other" + ), + "one", + id="two_kwargs", + ), + pytest.param( + lambda: ibis.literal(1).cases( + default="other", case_result_pairs=[(1, "one"), (2, "two")] + ), + "one", + id="two_kwargs_swapped", + ), + pytest.param( + lambda: ibis.literal(5).cases([(1, "one"), (2, "two")], "other"), + "other", + id="other", + ), + pytest.param( + lambda: ibis.literal(5).cases([(1, "one"), (2, "two")]), + None, + id="fallthrough", + ), + ], +) +def test_value_cases_old_api_is_deprecated(con, inp, exp): + with pytest.warns(FutureWarning): + i = inp() + result = con.execute(i) + if exp is None: + assert pd.isna(result) + else: + assert result == exp diff --git a/ibis/backends/tests/test_generic.py b/ibis/backends/tests/test_generic.py index f63d237323fdc..2017c87891998 100644 --- a/ibis/backends/tests/test_generic.py +++ b/ibis/backends/tests/test_generic.py @@ -390,12 +390,11 @@ def test_case_where(backend, alltypes, df): table = alltypes table = table.mutate( new_col=( - ibis.case() - .when(table["int_col"] == 1, 20) - .when(table["int_col"] == 0, 10) - .else_(0) - .end() - .cast("int64") + ibis.cases( + (table["int_col"] == 1, 20), + (table["int_col"] == 0, 10), + else_=0, + ).cast("int64") ) ) @@ -428,9 +427,7 @@ def test_select_filter_mutate(backend, alltypes, df): # Prepare the float_col so that filter must execute # before the cast to get the correct result. - t = t.mutate( - float_col=ibis.case().when(t["bool_col"], t["float_col"]).else_(np.nan).end() - ) + t = t.mutate(float_col=ibis.cases((t["bool_col"], t["float_col"]), else_=np.nan)) # Actual test t = t[t.columns] diff --git a/ibis/backends/tests/test_sql.py b/ibis/backends/tests/test_sql.py index 777cfa3db8bb3..c0bfb53f18170 100644 --- a/ibis/backends/tests/test_sql.py +++ b/ibis/backends/tests/test_sql.py @@ -59,16 +59,16 @@ def test_group_by_has_index(backend, snapshot): ) expr = countries.group_by( cont=( - _.continent.case() - .when("NA", "North America") - .when("SA", "South America") - .when("EU", "Europe") - .when("AF", "Africa") - .when("AS", "Asia") - .when("OC", "Oceania") - .when("AN", "Antarctica") - .else_("Unknown continent") - .end() + _.continent.cases( + ("NA", "North America"), + ("SA", "South America"), + ("EU", "Europe"), + ("AF", "Africa"), + ("AS", "Asia"), + ("OC", "Oceania"), + ("AN", "Antarctica"), + else_="Unknown continent", + ) ) ).agg(total_pop=_.population.sum()) sql = str(ibis.to_sql(expr, dialect=backend.name())) diff --git a/ibis/backends/tests/test_string.py b/ibis/backends/tests/test_string.py index ceb9fdc77711b..e2352055ea9e7 100644 --- a/ibis/backends/tests/test_string.py +++ b/ibis/backends/tests/test_string.py @@ -507,9 +507,9 @@ def uses_java_re(t): id="length", ), param( - lambda t: t.int_col.cases([(1, "abcd"), (2, "ABCD")], "dabc").startswith( - "abc" - ), + lambda t: t.int_col.cases( + (1, "abcd"), (2, "ABCD"), else_="dabc" + ).startswith("abc"), lambda t: t.int_col == 1, id="startswith", marks=[ @@ -517,7 +517,7 @@ def uses_java_re(t): ], ), param( - lambda t: t.int_col.cases([(1, "abcd"), (2, "ABCD")], "dabc").endswith( + lambda t: t.int_col.cases((1, "abcd"), (2, "ABCD"), else_="dabc").endswith( "bcd" ), lambda t: t.int_col == 1, @@ -693,11 +693,9 @@ def test_re_replace_global(con): @pytest.mark.notimpl(["druid"], raises=ValidationError) def test_substr_with_null_values(backend, alltypes, df): table = alltypes.mutate( - substr_col_null=ibis.case() - .when(alltypes["bool_col"], alltypes["string_col"]) - .else_(None) - .end() - .substr(0, 2) + substr_col_null=ibis.cases( + (alltypes["bool_col"], alltypes["string_col"]), else_=None + ).substr(0, 2) ) result = table.execute() @@ -910,7 +908,7 @@ def test_levenshtein(con, right): @pytest.mark.parametrize( "expr", [ - param(ibis.case().when(True, "%").end(), id="case"), + param(ibis.cases((True, "%")), id="case"), param(ibis.ifelse(True, "%", ibis.null()), id="ifelse"), ], ) diff --git a/ibis/backends/tests/test_struct.py b/ibis/backends/tests/test_struct.py index 8a7ed89144533..0e6f72a486dfc 100644 --- a/ibis/backends/tests/test_struct.py +++ b/ibis/backends/tests/test_struct.py @@ -155,7 +155,7 @@ def test_collect_into_struct(alltypes): @pytest.mark.notimpl(["flink"], raises=Py4JJavaError, reason="not implemented in ibis") def test_field_access_after_case(con): s = ibis.struct({"a": 3}) - x = ibis.case().when(True, s).else_(ibis.struct({"a": 4})).end() + x = ibis.cases((True, s), else_=ibis.struct({"a": 4})) y = x.a assert con.to_pandas(y) == 3 diff --git a/ibis/backends/tests/tpc/h/test_queries.py b/ibis/backends/tests/tpc/h/test_queries.py index cb549cdd0fe40..de94216eecd3d 100644 --- a/ibis/backends/tests/tpc/h/test_queries.py +++ b/ibis/backends/tests/tpc/h/test_queries.py @@ -272,9 +272,7 @@ def test_08(part, supplier, region, lineitem, orders, customer, nation): ] ) - q = q.mutate( - nation_volume=ibis.case().when(q.nation == NATION, q.volume).else_(0).end() - ) + q = q.mutate(nation_volume=ibis.cases((q.nation == NATION, q.volume), else_=0)) gq = q.group_by([q.o_year]) q = gq.aggregate(mkt_share=q.nation_volume.sum() / q.volume.sum()) q = q.order_by([q.o_year]) @@ -400,19 +398,15 @@ def test_12(orders, lineitem): gq = q.group_by([q.l_shipmode]) q = gq.aggregate( - high_line_count=( - q.o_orderpriority.case() - .when("1-URGENT", 1) - .when("2-HIGH", 1) - .else_(0) - .end() + high_line_count=q.o_orderpriority.cases( + ("1-URGENT", 1), + ("2-HIGH", 1), + else_=0, ).sum(), - low_line_count=( - q.o_orderpriority.case() - .when("1-URGENT", 0) - .when("2-HIGH", 0) - .else_(1) - .end() + low_line_count=q.o_orderpriority.cases( + ("1-URGENT", 0), + ("2-HIGH", 0), + else_=1, ).sum(), ) q = q.order_by(q.l_shipmode) diff --git a/ibis/expr/api.py b/ibis/expr/api.py index f8aca6d689392..92e1afefe0349 100644 --- a/ibis/expr/api.py +++ b/ibis/expr/api.py @@ -67,6 +67,7 @@ "array", "asc", "case", + "cases", "coalesce", "connect", "cross_join", @@ -1108,56 +1109,71 @@ def interval( return functools.reduce(operator.add, intervals) +@util.deprecated(instead="use ibis.cases() instead", as_of="9.1") def case() -> bl.SearchedCaseBuilder: - """Begin constructing a case expression. + """DEPRECATED: Use `ibis.cases()` instead.""" + return bl.SearchedCaseBuilder() + + +@deferrable +def cases(*branches: tuple[Any, Any], else_: Any | None = None) -> ir.Value: + """Create a multi-branch if-else expression. - Use the `.when` method on the resulting object followed by `.end` to create a - complete case expression. + Goes through each (condition, value) pair in `branches`, finding the + first condition that evaluates to True, and returns the corresponding + value. If no condition is True, returns `else_`. Returns ------- - SearchedCaseBuilder - A builder object to use for constructing a case expression. + Value + A value expression See Also -------- - [`Value.case()`](./expression-generic.qmd#ibis.expr.types.generic.Value.case) + [`Value.cases()`](./expression-generic.qmd#ibis.expr.types.generic.Value.cases) Examples -------- >>> import ibis - >>> from ibis import _ >>> ibis.options.interactive = True - >>> t = ibis.memtable( - ... { - ... "left": [1, 2, 3, 4], - ... "symbol": ["+", "-", "*", "/"], - ... "right": [5, 6, 7, 8], - ... } - ... ) - >>> t.mutate( - ... result=( - ... ibis.case() - ... .when(_.symbol == "+", _.left + _.right) - ... .when(_.symbol == "-", _.left - _.right) - ... .when(_.symbol == "*", _.left * _.right) - ... .when(_.symbol == "/", _.left / _.right) - ... .end() - ... ) - ... ) - ┏━━━━━━━┳━━━━━━━━┳━━━━━━━┳━━━━━━━━━┓ - ┃ left ┃ symbol ┃ right ┃ result ┃ - ┡━━━━━━━╇━━━━━━━━╇━━━━━━━╇━━━━━━━━━┩ - │ int64 │ string │ int64 │ float64 │ - ├───────┼────────┼───────┼─────────┤ - │ 1 │ + │ 5 │ 6.0 │ - │ 2 │ - │ 6 │ -4.0 │ - │ 3 │ * │ 7 │ 21.0 │ - │ 4 │ / │ 8 │ 0.5 │ - └───────┴────────┴───────┴─────────┘ - + >>> v = ibis.memtable({"values": [1, 2, 1, 2, 3, 2, 4]}).values + >>> ibis.cases((v == 1, "a"), (v > 2, "b"), else_="unk").name("cases") + ┏━━━━━━━━┓ + ┃ cases ┃ + ┡━━━━━━━━┩ + │ string │ + ├────────┤ + │ a │ + │ unk │ + │ a │ + │ unk │ + │ b │ + │ unk │ + │ b │ + └────────┘ + >>> ibis.cases( + ... (v % 2 == 0, "divisible by 2"), + ... (v % 3 == 0, "divisible by 3"), + ... (v % 4 == 0, "shadowed by the 2 case"), + ... ).name("cases") + ┏━━━━━━━━━━━━━━━━┓ + ┃ cases ┃ + ┡━━━━━━━━━━━━━━━━┩ + │ string │ + ├────────────────┤ + │ NULL │ + │ divisible by 2 │ + │ NULL │ + │ divisible by 2 │ + │ divisible by 3 │ + │ divisible by 2 │ + │ divisible by 2 │ + └────────────────┘ """ - return bl.SearchedCaseBuilder() + if not branches: + raise ValueError("At least one branch is required") + cases, results = zip(*branches) + return ops.SearchedCase(cases=cases, results=results, default=else_).to_expr() def now() -> ir.TimestampScalar: diff --git a/ibis/expr/decompile.py b/ibis/expr/decompile.py index 7d87550a9bcf6..2d73d35b2570c 100644 --- a/ibis/expr/decompile.py +++ b/ibis/expr/decompile.py @@ -304,16 +304,12 @@ def ifelse(op, bool_expr, true_expr, false_null_expr): @translate.register(ops.SimpleCase) @translate.register(ops.SearchedCase) -def switch_case(op, cases, results, default, base=None): - out = f"{base}.case()" if base else "ibis.case()" - - for case, result in zip(cases, results): - out = f"{out}.when({case}, {result})" - - if default is not None: - out = f"{out}.else_({default})" - - return f"{out}.end()" +def switch_cases(op, cases, results, default, base=None): + namespace = f"{base}" if base else "ibis" + case_strs = [f"({case}, {result})" for case, result in zip(cases, results)] + cases_str = ", ".join(case_strs) + else_str = f", else_={default}" if default is not None else "" + return f"{namespace}.cases({cases_str}{else_str})" _infix_ops = { diff --git a/ibis/expr/operations/logical.py b/ibis/expr/operations/logical.py index bc033f66318ed..7ea03f4d70e85 100644 --- a/ibis/expr/operations/logical.py +++ b/ibis/expr/operations/logical.py @@ -154,7 +154,7 @@ class IfElse(Value): Equivalent to ```python - bool_expr.case().when(True, true_expr).else_(false_or_null_expr) + bool_expr.cases((True, true_expr), else_=false_or_null_expr) ``` Many backends implement this as a built-in function. diff --git a/ibis/expr/types/generic.py b/ibis/expr/types/generic.py index d9b087fbd4f19..4eb976d7f2be1 100644 --- a/ibis/expr/types/generic.py +++ b/ibis/expr/types/generic.py @@ -1,6 +1,7 @@ from __future__ import annotations -from collections.abc import Iterable, Sequence +import warnings +from collections.abc import Sequence from typing import TYPE_CHECKING, Any from public import public @@ -10,6 +11,7 @@ import ibis.expr.builders as bl import ibis.expr.datatypes as dt import ibis.expr.operations as ops +from ibis import util from ibis.common.deferred import Deferred, _, deferrable from ibis.common.grounds import Singleton from ibis.expr.rewrites import rewrite_window_input @@ -723,19 +725,18 @@ def substitute( └────────┴──────────────┘ """ if isinstance(value, dict): - expr = ibis.case() - try: - null_replacement = value.pop(None) - except KeyError: - pass - else: - expr = expr.when(self.isnull(), null_replacement) - for k, v in value.items(): - expr = expr.when(self == k, v) + branches = list(value.items()) else: - expr = self.case().when(value, replacement) - - return expr.else_(else_ if else_ is not None else self).end() + branches = [(value, replacement)] + nulls = [(k, v) for k, v in branches if k is None] + nonnulls = [(k, v) for k, v in branches if k is not None] + if nulls: + null_replacement = nulls[0][1] + self = self.fill_null(null_replacement) + else_ = else_ if else_ is not None else self + if not nonnulls: + return else_ + return self.cases(*nonnulls, else_=else_) def over( self, @@ -871,99 +872,80 @@ def notnull(self) -> ir.BooleanValue: """ return ops.NotNull(self).to_expr() + @util.deprecated(instead="use Value.cases() instead", as_of="9.1") def case(self) -> bl.SimpleCaseBuilder: - """Create a SimpleCaseBuilder to chain multiple if-else statements. - - Add new search expressions with the `.when()` method. These must be - comparable with this column expression. Conclude by calling `.end()`. - - Returns - ------- - SimpleCaseBuilder - A case builder - - See Also - -------- - [`Value.substitute()`](./expression-generic.qmd#ibis.expr.types.generic.Value.substitute) - [`ibis.cases()`](./expression-generic.qmd#ibis.expr.types.generic.Value.cases) - [`ibis.case()`](./expression-generic.qmd#ibis.case) + """DEPRECATED: Use `self.cases()` instead.""" + return bl.SimpleCaseBuilder(self.op()) - Examples - -------- - >>> import ibis - >>> ibis.options.interactive = True - >>> x = ibis.examples.penguins.fetch().head(5)["sex"] - >>> x - ┏━━━━━━━━┓ - ┃ sex ┃ - ┡━━━━━━━━┩ - │ string │ - ├────────┤ - │ male │ - │ female │ - │ female │ - │ NULL │ - │ female │ - └────────┘ - >>> x.case().when("male", "M").when("female", "F").else_("U").end() - ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ - ┃ SimpleCase(sex, ('male', 'female'), ('M', 'F'), 'U') ┃ - ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ - │ string │ - ├──────────────────────────────────────────────────────┤ - │ M │ - │ F │ - │ F │ - │ U │ - │ F │ - └──────────────────────────────────────────────────────┘ - - Cases not given result in the ELSE case - - >>> x.case().when("male", "M").else_("OTHER").end() - ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ - ┃ SimpleCase(sex, ('male',), ('M',), 'OTHER') ┃ - ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ - │ string │ - ├─────────────────────────────────────────────┤ - │ M │ - │ OTHER │ - │ OTHER │ - │ OTHER │ - │ OTHER │ - └─────────────────────────────────────────────┘ - - If you don't supply an ELSE, then NULL is used - - >>> x.case().when("male", "M").end() - ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ - ┃ SimpleCase(sex, ('male',), ('M',), Cast(None, string)) ┃ - ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ - │ string │ - ├────────────────────────────────────────────────────────┤ - │ M │ - │ NULL │ - │ NULL │ - │ NULL │ - │ NULL │ - └────────────────────────────────────────────────────────┘ - """ - import ibis.expr.builders as bl + @staticmethod + def _norm_cases_args(*args, **kwargs): + # TODO: remove in v10.0 once we have a deprecation cycle + # before, the API for Value.cases() was + # def cases( + # self, + # case_result_pairs: Iterable[tuple[Value, Value]], + # default: Value | None = None, + # ) -> Value: + # Now it is + # def cases( + # self, + # *branches: tuple[Value, Value], + # else_: Value | None = None, + # ) -> Value: + # This method normalizes the arguments to the new API. + using_old_api = False + branches = [] + else_ = None + if len(args) >= 1: + first_arg = args[0] + first_arg = util.promote_list(first_arg) + if len(first_arg) > 0 and isinstance(first_arg[0], tuple): + # called as .cases([(test, result), ...], ) + using_old_api = True + branches = first_arg + else_ = args[1] if len(args) == 2 else None + else: + # called as .cases((test, result), ...) + branches = list(args) + + if "case_result_pairs" in kwargs: + using_old_api = True + branches = list(kwargs["case_result_pairs"]) + elif "branches" in kwargs: + branches = list(kwargs["branches"]) + + if "default" in kwargs: + using_old_api = True + else_ = kwargs["default"] + elif "else_" in kwargs: + else_ = kwargs["else_"] + + if using_old_api: + warnings.warn( + "You are using the old API for `cases()`. Please see" + " https://ibis-project.org/reference/expression-generic" + " on how to upgrade to the new API.", + FutureWarning, + ) + return branches, else_ - return bl.SimpleCaseBuilder(self.op()) + def cases(self, *args, **kwargs) -> Value: # noqa: D417 + """Create a multi-branch if-else expression. - def cases( - self, - case_result_pairs: Iterable[tuple[ir.BooleanValue, Value]], - default: Value | None = None, - ) -> Value: - """Create a case expression in one shot. + This is semantically equivalent to + CASE self + WHEN test_val0 THEN result0 + WHEN test_val1 THEN result1 + ELSE else_ + END Parameters ---------- - case_result_pairs - Conditional-result pairs - default + branches + (test_val, result) pairs. We look through the test values in order + and return the result corresponding to the first test value that + matches `self`. If none match, we return `else_`. + else_ Value to return if none of the case conditions are true Returns @@ -974,48 +956,56 @@ def cases( See Also -------- [`Value.substitute()`](./expression-generic.qmd#ibis.expr.types.generic.Value.substitute) - [`ibis.cases()`](./expression-generic.qmd#ibis.expr.types.generic.Value.cases) - [`ibis.case()`](./expression-generic.qmd#ibis.case) + [`ibis.cases()`](./expression-generic.qmd#ibis.cases) Examples -------- >>> import ibis >>> ibis.options.interactive = True - >>> t = ibis.memtable({"values": [1, 2, 1, 2, 3, 2, 4]}) - >>> t - ┏━━━━━━━━┓ - ┃ values ┃ - ┡━━━━━━━━┩ - │ int64 │ - ├────────┤ - │ 1 │ - │ 2 │ - │ 1 │ - │ 2 │ - │ 3 │ - │ 2 │ - │ 4 │ - └────────┘ - >>> number_letter_map = ((1, "a"), (2, "b"), (3, "c")) - >>> t.values.cases(number_letter_map, default="unk").name("replace") - ┏━━━━━━━━━┓ - ┃ replace ┃ - ┡━━━━━━━━━┩ - │ string │ - ├─────────┤ - │ a │ - │ b │ - │ a │ - │ b │ - │ c │ - │ b │ - │ unk │ - └─────────┘ + >>> t = ibis.memtable( + ... { + ... "left": [5, 6, 7, 8, 9, 10], + ... "symbol": ["+", "-", "*", "/", "bogus", None], + ... "right": [1, 2, 3, 4, 5, 6], + ... } + ... ) + + Note we never hit the `None` case, because `x = NULL` is always NULL, + which is not truthy. If you want to replace NULLs, you should use + `.fillna(-999)` prior to `cases()`. + + >>> t.mutate( + ... result=( + ... t.symbol.cases( + ... ("+", t.left + t.right), + ... ("-", t.left - t.right), + ... ("*", t.left * t.right), + ... ("/", t.left / t.right), + ... (None, -999), + ... ) + ... ) + ... ) + ┏━━━━━━━┳━━━━━━━━┳━━━━━━━┳━━━━━━━━━┓ + ┃ left ┃ symbol ┃ right ┃ result ┃ + ┡━━━━━━━╇━━━━━━━━╇━━━━━━━╇━━━━━━━━━┩ + │ int64 │ string │ int64 │ float64 │ + ├───────┼────────┼───────┼─────────┤ + │ 5 │ + │ 1 │ 6.0 │ + │ 6 │ - │ 2 │ 4.0 │ + │ 7 │ * │ 3 │ 21.0 │ + │ 8 │ / │ 4 │ 2.0 │ + │ 9 │ bogus │ 5 │ NULL │ + │ 10 │ NULL │ 6 │ NULL │ + └───────┴────────┴───────┴─────────┘ """ - builder = self.case() - for case, result in case_result_pairs: - builder = builder.when(case, result) - return builder.else_(default).end() + branches, else_ = self._norm_cases_args(*args, **kwargs) + + if not branches: + raise ValueError("At least one branch is required") + cases, results = zip(*branches) + return ops.SimpleCase( + base=self, cases=cases, results=results, default=else_ + ).to_expr() def collect(self, where: ir.BooleanValue | None = None) -> ir.ArrayScalar: """Aggregate this expression's elements into an array. diff --git a/ibis/expr/types/numeric.py b/ibis/expr/types/numeric.py index c99c357c74703..c4a300ccf532a 100644 --- a/ibis/expr/types/numeric.py +++ b/ibis/expr/types/numeric.py @@ -1,6 +1,5 @@ from __future__ import annotations -import functools from typing import TYPE_CHECKING, Literal from public import public @@ -1149,13 +1148,7 @@ def label(self, labels: Iterable[str], nulls: str | None = None) -> ir.StringVal │ 2 │ c │ └───────┴─────────┘ """ - return ( - functools.reduce( - lambda stmt, inputs: stmt.when(*inputs), enumerate(labels), self.case() - ) - .else_(nulls) - .end() - ) + return self.cases(*enumerate(labels), else_=nulls) @public diff --git a/ibis/expr/types/relations.py b/ibis/expr/types/relations.py index 20b7469a1a177..ca24940db6330 100644 --- a/ibis/expr/types/relations.py +++ b/ibis/expr/types/relations.py @@ -2941,9 +2941,7 @@ def info(self) -> Table: for pos, colname in enumerate(self.columns): col = self[colname] typ = col.type() - agg = self.select( - isna=ibis.case().when(col.isnull(), 1).else_(0).end() - ).agg( + agg = self.select(isna=ibis.cases((col.isnull(), 1), else_=0)).agg( name=lit(colname), type=lit(str(typ)), nullable=lit(typ.nullable), diff --git a/ibis/tests/expr/test_case.py b/ibis/tests/expr/test_case.py index dbd0b9d21746f..3d34e0c5c272b 100644 --- a/ibis/tests/expr/test_case.py +++ b/ibis/tests/expr/test_case.py @@ -8,7 +8,7 @@ import ibis.expr.types as ir from ibis import _ from ibis.common.annotations import SignatureValidationError -from ibis.tests.util import assert_equal, assert_pickle_roundtrip +from ibis.tests.util import assert_pickle_roundtrip def test_ifelse_method(table): @@ -82,72 +82,43 @@ def test_case_dtype(): ibis.case().when(True, 5).else_("bar").end() -def test_simple_case_expr(table): - case1, result1 = "foo", table.a - case2, result2 = "bar", table.c - default_result = table.b - - expr1 = table.g.lower().cases( - [(case1, result1), (case2, result2)], default=default_result - ) - - expr2 = ( - table.g.lower() - .case() - .when(case1, result1) - .when(case2, result2) - .else_(default_result) - .end() - ) - - assert_equal(expr1, expr2) - assert isinstance(expr1, ir.IntegerColumn) - - def test_multiple_case_expr(table): - expr = ( - ibis.case() - .when(table.a == 5, table.f) - .when(table.b == 128, table.b * 2) - .when(table.c == 1000, table.e) - .else_(table.d) - .end() + expr = ibis.cases( + (table.a == 5, table.f), + (table.b == 128, table.b * 2), + (table.c == 1000, table.e), + else_=table.d, ) # deferred cases - deferred = ( - ibis.case() - .when(_.a == 5, table.f) - .when(_.b == 128, table.b * 2) - .when(_.c == 1000, table.e) - .else_(table.d) - .end() + deferred = ibis.cases( + (_.a == 5, table.f), + (_.b == 128, table.b * 2), + (_.c == 1000, table.e), + else_=table.d, ) expr2 = deferred.resolve(table) # deferred results - expr3 = ( - ibis.case() - .when(table.a == 5, _.f) - .when(table.b == 128, _.b * 2) - .when(table.c == 1000, _.e) - .else_(table.d) - .end() - .resolve(table) - ) + expr3 = ibis.cases( + (table.a == 5, _.f), + (table.b == 128, _.b * 2), + (table.c == 1000, _.e), + else_=table.d, + ).resolve(table) # deferred default - expr4 = ( - ibis.case() - .when(table.a == 5, table.f) - .when(table.b == 128, table.b * 2) - .when(table.c == 1000, table.e) - .else_(_.d) - .end() - .resolve(table) + expr4 = ibis.cases( + (table.a == 5, table.f), + (table.b == 128, table.b * 2), + (table.c == 1000, table.e), + else_=_.d, + ).resolve(table) + + assert ( + repr(deferred) + == "cases(((_.a == 5), ), ((_.b == 128), ), ((_.c == 1000), ), else_=)" ) - - assert repr(deferred) == "" assert expr.equals(expr2) assert expr.equals(expr3) assert expr.equals(expr4) @@ -168,13 +139,11 @@ def test_pickle_multiple_case_node(table): result3 = table.e default = table.d - expr = ( - ibis.case() - .when(case1, result1) - .when(case2, result2) - .when(case3, result3) - .else_(default) - .end() + expr = ibis.cases( + (case1, result1), + (case2, result2), + (case3, result3), + else_=default, ) op = expr.op() @@ -182,7 +151,7 @@ def test_pickle_multiple_case_node(table): def test_simple_case_null_else(table): - expr = table.g.case().when("foo", "bar").end() + expr = table.g.cases(("foo", "bar")) op = expr.op() assert isinstance(expr, ir.StringColumn) @@ -192,8 +161,8 @@ def test_simple_case_null_else(table): def test_multiple_case_null_else(table): - expr = ibis.case().when(table.g == "foo", "bar").end() - expr2 = ibis.case().when(table.g == "foo", _).end().resolve("bar") + expr = ibis.cases((table.g == "foo", "bar")) + expr2 = ibis.cases((table.g == "foo", _)).resolve("bar") assert expr.equals(expr2) @@ -208,32 +177,43 @@ def test_case_mixed_type(): name="my_data", ) - expr = ( - t0.three.case().when(0, "low").when(1, "high").else_("null").end().name("label") - ) + expr = t0.three.cases((0, "low"), (1, "high"), else_="null").name("label") result = t0[expr] assert result["label"].type().equals(dt.string) +def test_err_on_bad_args(table): + with pytest.raises(ValueError): + ibis.cases((True,)) + with pytest.raises(ValueError): + ibis.cases((True, 3, 4)) + with pytest.raises(ValueError): + ibis.cases((True, 3, 4)) + with pytest.raises(ValueError): + ibis.cases((True, 3), 5) + + def test_err_on_nonbool_expr(table): with pytest.raises(SignatureValidationError): - ibis.case().when(table.a, "bar").else_("baz").end() + ibis.cases((table.a, "bar"), else_="baz") with pytest.raises(SignatureValidationError): - ibis.case().when(ibis.literal(1), "bar").else_("baz").end() + ibis.cases((ibis.literal(1), "bar"), else_=("baz")) def test_err_on_noncomparable(table): + table.a.cases((8, "bar")) + table.a.cases((-8, "bar")) # Can't compare an int to a string with pytest.raises(TypeError): - table.a.case().when("foo", "bar").end() + table.a.cases(("foo", "bar")) def test_err_on_empty_cases(table): - with pytest.raises(SignatureValidationError): - ibis.case().end() - with pytest.raises(SignatureValidationError): - ibis.case().else_(42).end() - with pytest.raises(SignatureValidationError): - table.a.case().end() - with pytest.raises(SignatureValidationError): - table.a.case().else_(42).end() + with pytest.raises(ValueError): + ibis.cases() + with pytest.raises(ValueError): + ibis.cases(else_=42) + with pytest.raises(ValueError): + table.a.cases() + with pytest.raises(ValueError): + table.a.cases(else_=42) diff --git a/ibis/tests/expr/test_value_exprs.py b/ibis/tests/expr/test_value_exprs.py index 0b493dd749c2e..19d29eab62eb3 100644 --- a/ibis/tests/expr/test_value_exprs.py +++ b/ibis/tests/expr/test_value_exprs.py @@ -834,23 +834,11 @@ def test_substitute_dict(): subs = {"a": "one", "b": table.bar} result = table.foo.substitute(subs) - expected = ( - ibis.case() - .when(table.foo == "a", "one") - .when(table.foo == "b", table.bar) - .else_(table.foo) - .end() - ) + expected = table.foo.cases(("a", "one"), ("b", table.bar), else_=table.foo) assert_equal(result, expected) result = table.foo.substitute(subs, else_=ibis.null()) - expected = ( - ibis.case() - .when(table.foo == "a", "one") - .when(table.foo == "b", table.bar) - .else_(ibis.null()) - .end() - ) + expected = table.foo.cases(("a", "one"), ("b", table.bar), else_=ibis.null()) assert_equal(result, expected)