From a1b0b83f49e4dd74f3fd3a6237113699e65705d0 Mon Sep 17 00:00:00 2001
From: Nick Crews
Date: Wed, 13 Mar 2024 12:10:40 -0800
Subject: [PATCH] feat: support all dtypes in MapGet and MapContains

---

Notes:
    Summary of the diff below (descriptive only):
    - ibis/expr/operations/maps.py: the `key` field of MapGet and of
      MapContains is re-annotated from `Value[dt.String | dt.Integer]` to
      plain `Value`.
    - ibis/backends/tests/test_map.py: adds the parametrized fixtures `keys`
      and `values` plus four tests. test_map_get_all_types and
      test_map_contains_all_types build a map from each keys/values pairing
      and execute `m[key]` / `m.contains(key)` through `con`.
    - test_map_get_bad_key_type and test_map_contains_bad_key_type only build
      the expression and assert nothing; their TODO comments state that the
      call should eventually raise (issue 8605, linked in the diff).

 ibis/backends/tests/test_map.py | 117 ++++++++++++++++++++++++++++++++
 ibis/expr/operations/maps.py    |   4 +-
 2 files changed, 119 insertions(+), 2 deletions(-)

diff --git a/ibis/backends/tests/test_map.py b/ibis/backends/tests/test_map.py
index a80945653325f..6f74c1270d534 100644
--- a/ibis/backends/tests/test_map.py
+++ b/ibis/backends/tests/test_map.py
@@ -186,6 +186,123 @@ def test_literal_map_getitem_broadcast(backend, alltypes, df):
     backend.assert_series_equal(result, expected)
 
 
+@pytest.fixture(
+    params=[
+        pytest.param(
+            ["a", "b"],
+            id="string",
+        ),
+        pytest.param(
+            [1, 2],
+            marks=pytest.mark.notyet("clickhouse", reason="only supports string keys"),
+            id="int",
+        ),
+        pytest.param(
+            [1.0, 2.0],
+            marks=pytest.mark.notyet("clickhouse", reason="only supports string keys"),
+            id="float",
+        ),
+        pytest.param(
+            [True, False],
+            marks=pytest.mark.notyet("clickhouse", reason="only supports string keys"),
+            id="bool",
+        ),
+        pytest.param(
+            [ibis.date(1, 2, 3), ibis.date(4, 5, 6)],
+            marks=[
+                pytest.mark.notyet("clickhouse", reason="only supports string keys"),
+                pytest.mark.notyet(["pandas", "dask"]),
+            ],
+            id="date",
+        ),
+        pytest.param(
+            [[1, 2], [3, 4]],
+            marks=[
+                pytest.mark.notyet("clickhouse", reason="only supports string keys"),
+                pytest.mark.notyet(["pandas", "dask"]),
+            ],
+            id="array",
+        ),
+        pytest.param(
+            [ibis.struct(dict(a=1)), ibis.struct(dict(a=2))],
+            marks=[
+                pytest.mark.notyet("clickhouse", reason="only supports string keys"),
+                pytest.mark.notyet(["pandas", "dask"]),
+            ],
+            id="struct",
+        ),
+        pytest.param(
+            [ibis.timestamp("2021-01-01"), ibis.timestamp("2021-01-02")],
+            marks=[
+                pytest.mark.notyet("clickhouse", reason="only supports string keys"),
+            ],
+            id="timestamp",
+        ),
+    ]
+)
+def keys(request):
+    return request.param
+
+
+@pytest.fixture(
+    params=[
+        pytest.param([1, 2], id="int"),
+        pytest.param([1.0, 2.0], id="float"),
+        pytest.param([True, False], id="bool"),
+        pytest.param(["a", "b"], id="string"),
+        pytest.param([[1, 2], [3, 4]], id="array"),
+        pytest.param(
+            [ibis.struct(dict(a=1)), ibis.struct(dict(a=2))],
+            id="struct",
+        ),
+        pytest.param(
+            [ibis.date(1, 2, 3), ibis.date(4, 5, 6)],
+            marks=pytest.mark.notimpl(
+                ["pandas", "dask"], reason="DateFromYMD isn't implemented"
+            ),
+            id="date",
+        ),
+        pytest.param(
+            [ibis.timestamp("2021-01-01"), ibis.timestamp("2021-01-02")], id="timestamp"
+        ),
+    ]
+)
+def values(request):
+    return request.param
+
+
+@pytest.mark.notyet(
+    ["postgres", "risingwave"],
+    reason="only support maps of string -> string",
+)
+def test_map_get_all_types(con, keys, values):
+    m = ibis.map(ibis.array(keys), ibis.array(values))
+    for key, val in zip(keys, values):
+        if isinstance(val, ibis.Expr):
+            val = con.execute(val)
+        assert con.execute(m[key]) == val
+
+
+@pytest.mark.notyet(
+    ["postgres", "risingwave"],
+    reason="only support maps of string -> string",
+)
+def test_map_contains_all_types(con, keys, values):
+    m = ibis.map(ibis.array(keys), ibis.array(values))
+    for key in keys:
+        assert con.execute(m.contains(key))
+
+
+# TODO: this should actually error: https://github.com/ibis-project/ibis/issues/8605
+def test_map_get_bad_key_type():
+    ibis.map({"1": "a"})[2]
+
+
+# TODO: this should actually error: https://github.com/ibis-project/ibis/issues/8605
+def test_map_contains_bad_key_type():
+    ibis.map({"1": "a"}).contains(2)
+
+
 @pytest.mark.notimpl(
     ["risingwave"],
     raises=PsycoPg2InternalError,
diff --git a/ibis/expr/operations/maps.py b/ibis/expr/operations/maps.py
index 8323d880e6dc4..edd4d0505c90e 100644
--- a/ibis/expr/operations/maps.py
+++ b/ibis/expr/operations/maps.py
@@ -32,7 +32,7 @@ class MapLength(Unary):
 @public
 class MapGet(Value):
     arg: Value[dt.Map]
-    key: Value[dt.String | dt.Integer]
+    key: Value
     default: Value = None
 
     shape = rlz.shape_like("args")
@@ -45,7 +45,7 @@ def dtype(self):
 @public
 class MapContains(Value):
     arg: Value[dt.Map]
-    key: Value[dt.String | dt.Integer]
+    key: Value
 
     shape = rlz.shape_like("args")
     dtype = dt.bool