diff --git a/ibis/backends/snowflake/registry.py b/ibis/backends/snowflake/registry.py
index 38ea35f9ed94..f981752f6bf6 100644
--- a/ibis/backends/snowflake/registry.py
+++ b/ibis/backends/snowflake/registry.py
@@ -1,10 +1,11 @@
 from __future__ import annotations
 
 import itertools
+import json
 
 import numpy as np
 import sqlalchemy as sa
-from snowflake.sqlalchemy.custom_types import VARIANT
+from snowflake.sqlalchemy import VARIANT
 
 import ibis.expr.operations as ops
 from ibis.backends.base.sql.alchemy.registry import (
@@ -52,7 +53,7 @@ def _literal(t, op):
         return sa.func.array_construct(*value)
     elif dtype.is_map():
         return sa.func.object_construct_keep_null(
-            *zip(itertools.chain.from_iterable(value.items()))
+            *itertools.chain.from_iterable(value.items())
         )
     return _postgres_literal(t, op)
 
@@ -112,6 +113,17 @@ def _array_slice(t, op):
     return sa.func.array_slice(t.translate(op.arg), start, stop)
 
 
+def _map(_, op):
+    if not (
+        isinstance(keys := op.keys, ops.Literal)
+        and isinstance(values := op.values, ops.Literal)
+    ):
+        raise TypeError("Both keys and values of an `ibis.map` call must be literals")
+
+    obj = dict(zip(keys.value, values.value))
+    return sa.func.to_object(sa.func.parse_json(json.dumps(obj, separators=",:")))
+
+
 _SF_POS_INF = sa.cast(sa.literal("Inf"), sa.FLOAT)
 _SF_NEG_INF = -_SF_POS_INF
 _SF_NAN = sa.cast(sa.literal("NaN"), sa.FLOAT)
@@ -201,7 +213,7 @@
             ops.ArrayColumn: lambda t, op: sa.func.array_construct(
                 *map(t.translate, op.cols)
             ),
-            ops.ArraySlice: _array_slice,
+            ops.Map: _map,
         }
     )
 
diff --git a/ibis/backends/tests/test_map.py b/ibis/backends/tests/test_map.py
index 43b1add16321..cc8ce16e44b7 100644
--- a/ibis/backends/tests/test_map.py
+++ b/ibis/backends/tests/test_map.py
@@ -124,12 +124,16 @@ def test_literal_map_get_broadcast(backend, alltypes, df):
     backend.assert_series_equal(result, expected)
 
 
-@pytest.mark.notyet(["snowflake"])
-def test_map_construction(con, alltypes, df):
+def test_map_construct_dict(con):
     expr = ibis.map(['a', 'b'], [1, 2])
     result = con.execute(expr.name('tmp'))
     assert result == {'a': 1, 'b': 2}
 
+
+@pytest.mark.notimpl(
+    ["snowflake"], reason="unclear how to implement two arrays -> object construction"
+)
+def test_map_construct_array_column(con, alltypes, df):
     expr = ibis.map(ibis.array([alltypes.string_col]), ibis.array([alltypes.int_col]))
     result = con.execute(expr)
     expected = df.apply(lambda row: {row['string_col']: row['int_col']}, axis=1)
diff --git a/ibis/backends/tests/test_param.py b/ibis/backends/tests/test_param.py
index 45632e31cbd4..b886ff96e166 100644
--- a/ibis/backends/tests/test_param.py
+++ b/ibis/backends/tests/test_param.py
@@ -58,7 +58,7 @@ def test_timestamp_accepts_date_literals(alltypes):
     assert expr.compile(params=params) is not None
 
 
-@pytest.mark.notimpl(["dask", "datafusion", "impala", "pandas", "pyspark", "snowflake"])
+@pytest.mark.notimpl(["dask", "datafusion", "impala", "pandas", "pyspark"])
 @pytest.mark.never(
     ["mysql", "sqlite", "mssql"], reason="backend will never implement array types"
 )
@@ -84,17 +84,7 @@ def test_scalar_param_struct(con):
 
 
 @pytest.mark.notimpl(
-    [
-        "clickhouse",
-        "datafusion",
-        # TODO: duckdb maps are tricky because they are multimaps
-        "duckdb",
-        "impala",
-        "pyspark",
-        "snowflake",
-        "polars",
-        "trino",
-    ]
+    ["clickhouse", "datafusion", "duckdb", "impala", "pyspark", "polars", "trino"]
 )
 @pytest.mark.never(
     ["mysql", "sqlite", "mssql"],