diff --git a/ibis/backends/flink/compiler.py b/ibis/backends/flink/compiler.py index 562e0753bb03..81fe6c9e82d6 100644 --- a/ibis/backends/flink/compiler.py +++ b/ibis/backends/flink/compiler.py @@ -546,7 +546,11 @@ def visit_CountDistinct(self, op, *, arg, where): return self.f.count(sge.Distinct(expressions=[arg])) def visit_MapContains(self, op: ops.MapContains, *, arg, key): - return self.f.array_contains(self.f.map_keys(arg), key) + key_type = op.arg.dtype.key_type + return self.f.array_contains( + self.cast(self.f.map_keys(arg), dt.Array(value_type=key_type)), + self.cast(key, key_type), + ) def visit_Map(self, op: ops.Map, *, keys, values): return self.cast(self.f.map_from_arrays(keys, values), op.dtype) diff --git a/ibis/backends/snowflake/compiler.py b/ibis/backends/snowflake/compiler.py index 4437e21df16e..8b980aa9aa8f 100644 --- a/ibis/backends/snowflake/compiler.py +++ b/ibis/backends/snowflake/compiler.py @@ -198,7 +198,7 @@ def visit_RegexSplit(self, op, *, arg, pattern): def visit_Map(self, op, *, keys, values): return self.if_( sg.and_(self.f.is_array(keys), self.f.is_array(values)), - self.f.udf.object_from_arrays(keys, values), + self.f.arrays_to_object(keys, values), NULL, ) diff --git a/ibis/backends/tests/test_map.py b/ibis/backends/tests/test_map.py index a80945653325..28ac4010ab82 100644 --- a/ibis/backends/tests/test_map.py +++ b/ibis/backends/tests/test_map.py @@ -10,7 +10,7 @@ import ibis import ibis.common.exceptions as exc import ibis.expr.datatypes as dt -from ibis.backends.tests.errors import PsycoPg2InternalError, Py4JJavaError +from ibis.backends.tests.errors import Py4JJavaError pytestmark = [ pytest.mark.never( @@ -25,6 +25,19 @@ ), ] +mark_notyet_postgres = pytest.mark.notyet( + "postgres", reason="only supports string -> string" +) + +mark_notyet_snowflake = pytest.mark.notyet( + "snowflake", reason="map keys must be strings" +) + +mark_notimpl_risingwave_hstore = pytest.mark.notimpl( + ["risingwave"], + reason="function hstore(character varying[], character varying[]) does not exist", +) + @pytest.mark.notimpl(["pandas", "dask"]) def test_map_table(backend): @@ -37,11 +50,7 @@ def test_map_table(backend): @pytest.mark.xfail_version( duckdb=["duckdb<0.8.0"], raises=exc.UnsupportedOperationError ) -@pytest.mark.notimpl( - ["risingwave"], - raises=PsycoPg2InternalError, - reason="function hstore(character varying[], character varying[]) does not exist", -) +@mark_notimpl_risingwave_hstore def test_column_map_values(backend): table = backend.map expr = table.select("idx", vals=table.kv.values()).order_by("idx") @@ -67,11 +76,7 @@ def test_column_map_merge(backend): tm.assert_series_equal(result, expected) -@pytest.mark.notimpl( - ["risingwave"], - raises=PsycoPg2InternalError, - reason="function hstore(character varying[], character varying[]) does not exist", -) +@mark_notimpl_risingwave_hstore def test_literal_map_keys(con): mapping = ibis.literal({"1": "a", "2": "b"}) expr = mapping.keys().name("tmp") @@ -82,11 +87,7 @@ def test_literal_map_keys(con): assert np.array_equal(result, ["1", "2"]) -@pytest.mark.notimpl( - ["risingwave"], - raises=PsycoPg2InternalError, - reason="function hstore(character varying[], character varying[]) does not exist", -) +@mark_notimpl_risingwave_hstore def test_literal_map_values(con): mapping = ibis.literal({"1": "a", "2": "b"}) expr = mapping.values().name("tmp") @@ -95,7 +96,8 @@ def test_literal_map_values(con): assert np.array_equal(result, ["a", "b"]) -@pytest.mark.notimpl(["postgres", "risingwave"]) +@mark_notimpl_risingwave_hstore +@mark_notyet_postgres def test_scalar_isin_literal_map_keys(con): mapping = ibis.literal({"a": 1, "b": 2}) a = ibis.literal("a") @@ -106,9 +108,8 @@ def test_scalar_isin_literal_map_keys(con): assert con.execute(false) == False # noqa: E712 -@pytest.mark.notyet( - ["postgres", "risingwave"], reason="only support maps of string -> string" -) +@mark_notimpl_risingwave_hstore +@mark_notyet_postgres def test_map_scalar_contains_key_scalar(con): mapping = ibis.literal({"a": 1, "b": 2}) a = ibis.literal("a") @@ -119,11 +120,7 @@ def test_map_scalar_contains_key_scalar(con): assert con.execute(false) == False # noqa: E712 -@pytest.mark.notimpl( - ["risingwave"], - raises=PsycoPg2InternalError, - reason="function hstore(character varying[], character varying[]) does not exist", -) +@mark_notimpl_risingwave_hstore def test_map_scalar_contains_key_column(backend, alltypes, df): value = {"1": "a", "3": "c"} mapping = ibis.literal(value) @@ -133,9 +130,8 @@ def test_map_scalar_contains_key_column(backend, alltypes, df): backend.assert_series_equal(result, expected) -@pytest.mark.notyet( - ["postgres", "risingwave"], reason="only support maps of string -> string" -) +@mark_notimpl_risingwave_hstore +@mark_notyet_postgres def test_map_column_contains_key_scalar(backend, alltypes, df): expr = ibis.map(ibis.array([alltypes.string_col]), ibis.array([alltypes.int_col])) series = df.apply(lambda row: {row["string_col"]: row["int_col"]}, axis=1) @@ -146,9 +142,8 @@ def test_map_column_contains_key_scalar(backend, alltypes, df): backend.assert_series_equal(result, series) -@pytest.mark.notyet( - ["postgres", "risingwave"], reason="only support maps of string -> string" -) +@mark_notimpl_risingwave_hstore +@mark_notyet_postgres def test_map_column_contains_key_column(alltypes): map_expr = ibis.map( ibis.array([alltypes.string_col]), ibis.array([alltypes.int_col]) @@ -158,9 +153,8 @@ def test_map_column_contains_key_column(alltypes): assert result.all() -@pytest.mark.notyet( - ["postgres", "risingwave"], reason="only support maps of string -> string" -) +@mark_notimpl_risingwave_hstore +@mark_notyet_postgres def test_literal_map_merge(con): a = ibis.literal({"a": 0, "b": 2}) b = ibis.literal({"a": 1, "c": 3}) @@ -169,11 +163,7 @@ def test_literal_map_merge(con): assert con.execute(expr) == {"a": 1, "b": 2, "c": 3} -@pytest.mark.notimpl( - ["risingwave"], - raises=PsycoPg2InternalError, - reason="function hstore(character varying[], character varying[]) does not exist", -) +@mark_notimpl_risingwave_hstore def test_literal_map_getitem_broadcast(backend, alltypes, df): value = {"1": "a", "2": "b"} @@ -186,11 +176,163 @@ def test_literal_map_getitem_broadcast(backend, alltypes, df): backend.assert_series_equal(result, expected) -@pytest.mark.notimpl( - ["risingwave"], - raises=PsycoPg2InternalError, - reason="function hstore(character varying[], character varying[]) does not exist", +keys = pytest.mark.parametrize( + "keys", + [ + pytest.param(["a", "b"], id="string"), + pytest.param( + [1, 2], + marks=[mark_notyet_postgres, mark_notyet_snowflake], + id="int", + ), + pytest.param( + [True, False], + marks=[mark_notyet_postgres, mark_notyet_snowflake], + id="bool", + ), + pytest.param( + [1.0, 2.0], + marks=[ + pytest.mark.notyet( + "clickhouse", reason="only supports str,int,bool,timestamp keys" + ), + mark_notyet_postgres, + mark_notyet_snowflake, + ], + id="float", + ), + pytest.param( + [ibis.timestamp("2021-01-01"), ibis.timestamp("2021-01-02")], + marks=[mark_notyet_postgres, mark_notyet_snowflake], + id="timestamp", + ), + pytest.param( + [ibis.date(1, 2, 3), ibis.date(4, 5, 6)], + marks=[ + pytest.mark.notyet( + "clickhouse", reason="only supports str,int,bool,timestamp keys" + ), + pytest.mark.notimpl( + ["pandas", "dask"], reason="DateFromYMD isn't implemented" + ), + mark_notyet_postgres, + mark_notyet_snowflake, + ], + id="date", + ), + pytest.param( + [[1, 2], [3, 4]], + marks=[ + pytest.mark.notyet( + "clickhouse", reason="only supports str,int,bool,timestamp keys" + ), + pytest.mark.notyet(["pandas", "dask"]), + mark_notyet_postgres, + mark_notyet_snowflake, + ], + id="array", + ), + pytest.param( + [ibis.struct(dict(a=1)), ibis.struct(dict(a=2))], + marks=[ + pytest.mark.notyet( + "clickhouse", reason="only supports str,int,bool,timestamp keys" + ), + pytest.mark.notyet(["pandas", "dask"]), + mark_notyet_postgres, + pytest.mark.notimpl("flink"), + mark_notyet_snowflake, + ], + id="struct", + ), + ], ) + + +values = pytest.mark.parametrize( + "values", + [ + pytest.param(["a", "b"], id="string"), + pytest.param( + [1, 2], + marks=[ + mark_notyet_postgres, + ], + id="int", + ), + pytest.param( + [True, False], + marks=[ + mark_notyet_postgres, + ], + id="bool", + ), + pytest.param( + [1.0, 2.0], + marks=[ + mark_notyet_postgres, + ], + id="float", + ), + pytest.param( + [ibis.timestamp("2021-01-01"), ibis.timestamp("2021-01-02")], + marks=[ + mark_notyet_postgres, + ], + id="timestamp", + ), + pytest.param( + [ibis.date(2021, 1, 1), ibis.date(2022, 2, 2)], + marks=[ + pytest.mark.notimpl( + ["pandas", "dask"], reason="DateFromYMD isn't implemented" + ), + mark_notyet_postgres, + ], + id="date", + ), + pytest.param( + [[1, 2], [3, 4]], + marks=[ + pytest.mark.notyet("clickhouse", reason="nested types can't be null"), + mark_notyet_postgres, + ], + id="array", + ), + pytest.param( + [ibis.struct(dict(a=1)), ibis.struct(dict(a=2))], + marks=[ + pytest.mark.notyet("clickhouse", reason="nested types can't be null"), + mark_notyet_postgres, + pytest.mark.notimpl("flink", reason="can't construct structs"), + ], + id="struct", + ), + ], +) + + +@values +@keys +@mark_notimpl_risingwave_hstore +def test_map_get_all_types(con, keys, values): + m = ibis.map(ibis.array(keys), ibis.array(values)) + for key, val in zip(keys, values): + if isinstance(val, ibis.Expr): + val = con.execute(val) + assert con.execute(m[key]) == val + + +@keys +@mark_notimpl_risingwave_hstore +def test_map_contains_all_types(con, keys): + a = ibis.array(keys) + m = ibis.map(a, a) + for key in keys: + assert con.execute(m.contains(key)) + + +@mark_notimpl_risingwave_hstore def test_literal_map_get_broadcast(backend, alltypes, df): value = {"1": "a", "2": "b"} @@ -218,20 +360,15 @@ def test_literal_map_get_broadcast(backend, alltypes, df): param(["a", "b"], ["1", "2"], id="int"), ], ) -@pytest.mark.notimpl( - ["risingwave"], - raises=PsycoPg2InternalError, - reason="function hstore(character varying[], character varying[]) does not exist", -) +@mark_notimpl_risingwave_hstore def test_map_construct_dict(con, keys, values): expr = ibis.map(keys, values) result = con.execute(expr.name("tmp")) assert result == dict(zip(keys, values)) -@pytest.mark.notyet( - ["postgres", "risingwave"], reason="only support maps of string -> string" -) +@mark_notimpl_risingwave_hstore +@mark_notyet_postgres @pytest.mark.broken( ["flink"], raises=pa.lib.ArrowInvalid, @@ -245,37 +382,33 @@ def test_map_construct_array_column(con, alltypes, df): assert result.to_list() == expected.to_list() -@pytest.mark.notyet( - ["postgres", "risingwave"], reason="only support maps of string -> string" -) +@mark_notimpl_risingwave_hstore +@mark_notyet_postgres def test_map_get_with_compatible_value_smaller(con): value = ibis.literal({"A": 1000, "B": 2000}) expr = value.get("C", 3) assert con.execute(expr) == 3 -@pytest.mark.notyet( - ["postgres", "risingwave"], reason="only support maps of string -> string" -) +@mark_notimpl_risingwave_hstore +@mark_notyet_postgres def test_map_get_with_compatible_value_bigger(con): value = ibis.literal({"A": 1, "B": 2}) expr = value.get("C", 3000) assert con.execute(expr) == 3000 -@pytest.mark.notyet( - ["postgres", "risingwave"], reason="only support maps of string -> string" -) +@mark_notimpl_risingwave_hstore +@mark_notyet_postgres def test_map_get_with_incompatible_value_different_kind(con): value = ibis.literal({"A": 1000, "B": 2000}) expr = value.get("C", 3.0) assert con.execute(expr) == 3.0 +@mark_notimpl_risingwave_hstore +@mark_notyet_postgres @pytest.mark.parametrize("null_value", [None, ibis.NA]) -@pytest.mark.notyet( - ["postgres", "risingwave"], reason="only support maps of string -> string" -) def test_map_get_with_null_on_not_nullable(con, null_value): map_type = dt.Map(dt.string, dt.Int16(nullable=False)) value = ibis.literal({"A": 1000, "B": 2000}).cast(map_type) @@ -288,11 +421,7 @@ def test_map_get_with_null_on_not_nullable(con, null_value): @pytest.mark.notyet( ["flink"], raises=Py4JJavaError, reason="Flink cannot handle typeless nulls" ) -@pytest.mark.notimpl( - ["risingwave"], - raises=PsycoPg2InternalError, - reason="function hstore(character varying[], character varying[]) does not exist", -) +@mark_notimpl_risingwave_hstore def test_map_get_with_null_on_null_type_with_null(con, null_value): value = ibis.literal({"A": None, "B": None}) expr = value.get("C", null_value) @@ -300,9 +429,8 @@ def test_map_get_with_null_on_null_type_with_null(con, null_value): assert pd.isna(result) -@pytest.mark.notyet( - ["postgres", "risingwave"], reason="only support maps of string -> string" -) +@mark_notimpl_risingwave_hstore +@mark_notyet_postgres @pytest.mark.notyet( ["flink"], raises=Py4JJavaError, reason="Flink cannot handle typeless nulls" ) @@ -317,11 +445,7 @@ def test_map_get_with_null_on_null_type_with_non_null(con): raises=exc.IbisError, reason="`tbl_properties` is required when creating table with schema", ) -@pytest.mark.notimpl( - ["risingwave"], - raises=PsycoPg2InternalError, - reason="function hstore(character varying[], character varying[]) does not exist", -) +@mark_notimpl_risingwave_hstore def test_map_create_table(con, temp_table): t = con.create_table( temp_table, @@ -335,11 +459,7 @@ def test_map_create_table(con, temp_table): raises=exc.OperationNotDefinedError, reason="No translation rule for ", ) -@pytest.mark.notimpl( - ["risingwave"], - raises=PsycoPg2InternalError, - reason="function hstore(character varying[], character varying[]) does not exist", -) +@mark_notimpl_risingwave_hstore def test_map_length(con): expr = ibis.literal(dict(a="A", b="B")).length() assert con.execute(expr) == 2 @@ -350,3 +470,10 @@ def test_map_keys_unnest(backend): expr = backend.map.kv.keys().unnest() result = expr.to_pandas() assert frozenset(result) == frozenset("abcdef") + + +@mark_notimpl_risingwave_hstore +def test_map_contains_null(con): + expr = ibis.map(["a"], ibis.literal([None], type="array")) + assert con.execute(expr.contains("a")) + assert not con.execute(expr.contains("b")) diff --git a/ibis/expr/operations/maps.py b/ibis/expr/operations/maps.py index 8323d880e6dc..edd4d0505c90 100644 --- a/ibis/expr/operations/maps.py +++ b/ibis/expr/operations/maps.py @@ -32,7 +32,7 @@ class MapLength(Unary): @public class MapGet(Value): arg: Value[dt.Map] - key: Value[dt.String | dt.Integer] + key: Value default: Value = None shape = rlz.shape_like("args") @@ -45,7 +45,7 @@ def dtype(self): @public class MapContains(Value): arg: Value[dt.Map] - key: Value[dt.String | dt.Integer] + key: Value shape = rlz.shape_like("args") dtype = dt.bool