diff --git a/ibis/backends/snowflake/registry.py b/ibis/backends/snowflake/registry.py
index fe219b44e9d0..6261370537bf 100644
--- a/ibis/backends/snowflake/registry.py
+++ b/ibis/backends/snowflake/registry.py
@@ -46,6 +46,8 @@ def _literal(t, op):
         return sa.func.timestamp_from_parts(*args)
     elif dtype.is_date():
         return sa.func.date_from_parts(value.year, value.month, value.day)
+    elif dtype.is_array():
+        return sa.func.array_construct(*value)
     return _postgres_literal(t, op)
 
 
@@ -88,6 +90,22 @@ def _extract_url_query(t, op):
     return sa.func.nullif(sa.func.as_varchar(r), "")
 
 
+def _array_slice(t, op):
+    arg = t.translate(op.arg)
+
+    if (start := op.start) is not None:
+        start = t.translate(start)
+    else:
+        start = 0
+
+    if (stop := op.stop) is not None:
+        stop = t.translate(stop)
+    else:
+        stop = sa.func.array_size(arg)
+
+    return sa.func.array_slice(t.translate(op.arg), start, stop)
+
+
 _SF_POS_INF = sa.cast(sa.literal("Inf"), sa.FLOAT)
 _SF_NEG_INF = -_SF_POS_INF
 _SF_NAN = sa.cast(sa.literal("NaN"), sa.FLOAT)
@@ -158,6 +176,13 @@ def _extract_url_query(t, op):
         ),
         # snowflake typeof only accepts VARIANT
         ops.TypeOf: unary(lambda arg: sa.func.typeof(sa.cast(arg, VARIANT))),
+        ops.ArrayIndex: fixed_arity(sa.func.get, 2),
+        ops.ArrayLength: fixed_arity(sa.func.array_size, 1),
+        ops.ArrayConcat: fixed_arity(sa.func.array_cat, 2),
+        ops.ArrayColumn: lambda t, op: sa.func.array_construct(
+            *map(t.translate, op.cols)
+        ),
+        ops.ArraySlice: _array_slice,
     }
 )
 
@@ -169,12 +194,7 @@ def _extract_url_query(t, op):
     ops.NTile,
     ops.NthValue,
     # ibis.expr.operations.array
-    ops.ArrayColumn,
-    ops.ArrayConcat,
-    ops.ArrayIndex,
-    ops.ArrayLength,
     ops.ArrayRepeat,
-    ops.ArraySlice,
     ops.Unnest,
     # ibis.expr.operations.maps
     ops.MapKeys,
diff --git a/ibis/backends/tests/test_array.py b/ibis/backends/tests/test_array.py
index 9e029638b1f5..33e2901b553f 100644
--- a/ibis/backends/tests/test_array.py
+++ b/ibis/backends/tests/test_array.py
@@ -22,7 +22,7 @@
 ]
 
 
-@pytest.mark.notimpl(["datafusion", "snowflake"])
+@pytest.mark.notimpl(["datafusion"])
 def test_array_column(backend, alltypes, df):
     expr = ibis.array([alltypes['double_col'], alltypes['double_col']])
     assert isinstance(expr, ir.ArrayColumn)
@@ -35,7 +35,6 @@ def test_array_column(backend, alltypes, df):
     backend.assert_series_equal(result, expected, check_names=False)
 
 
-@pytest.mark.notimpl(["snowflake"])
 def test_array_scalar(con):
     expr = ibis.array([1.0, 2.0, 3.0])
     assert isinstance(expr, ir.ArrayScalar)
@@ -48,7 +47,7 @@
     assert np.array_equal(result, expected)
 
 
-@pytest.mark.notimpl(["snowflake", "polars", "datafusion"])
+@pytest.mark.notimpl(["polars", "datafusion", "snowflake"])
 def test_array_repeat(con):
     expr = ibis.array([1.0, 2.0]) * 2
 
@@ -61,7 +60,7 @@
 
 
 # Issues #2370
-@pytest.mark.notimpl(["datafusion", "snowflake"])
+@pytest.mark.notimpl(["datafusion"])
 def test_array_concat(con):
     left = ibis.literal([1, 2, 3])
     right = ibis.literal([2, 1])
@@ -74,13 +73,12 @@
     assert np.array_equal(result, expected)
 
 
-@pytest.mark.notimpl(["datafusion", "snowflake"])
+@pytest.mark.notimpl(["datafusion"])
 def test_array_length(con):
     expr = ibis.literal([1, 2, 3]).length()
     assert con.execute(expr.name("tmp")) == 3
 
 
-@pytest.mark.notimpl(["snowflake"])
 def test_list_literal(con):
     arr = [1, 2, 3]
     expr = ibis.literal(arr)
@@ -91,7 +89,6 @@
     assert np.array_equal(result, arr)
 
 
-@pytest.mark.notimpl(["snowflake"])
 def test_np_array_literal(con):
     arr = np.array([1, 2, 3])
     expr = ibis.literal(arr)
@@ -103,7 +100,7 @@
 
 
 @pytest.mark.parametrize("idx", range(3))
-@pytest.mark.notimpl(["snowflake", "polars", "datafusion"])
+@pytest.mark.notimpl(["polars", "datafusion"])
 def test_array_index(con, idx):
     arr = [1, 2, 3]
     expr = ibis.literal(arr)
@@ -372,7 +369,7 @@ def test_unnest_default_name(con):
         (-3, -1),
     ],
)
-@pytest.mark.notimpl(["dask", "datafusion", "polars", "snowflake"])
+@pytest.mark.notimpl(["dask", "datafusion", "polars"])
 def test_array_slice(con, start, stop):
     array_types = con.tables.array_types
     expr = array_types.select(sliced=array_types.y[start:stop])
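
Note (not part of the diff): a minimal sketch of the expressions the updated tests exercise, which the new registry entries translate to Snowflake's ARRAY_CONSTRUCT, ARRAY_SIZE, GET, ARRAY_SLICE, and ARRAY_CAT functions. The connection parameters below are placeholders, and `con` is assumed to be a configured ibis Snowflake connection.

import ibis

# Placeholder credentials; substitute real values for an actual connection.
con = ibis.snowflake.connect(
    user="<user>",
    password="<password>",
    account="<account>",
    database="<database>/<schema>",
)

arr = ibis.literal([1, 2, 3])                           # ARRAY_CONSTRUCT(1, 2, 3)
con.execute(arr.length().name("len"))                   # ARRAY_SIZE  -> 3
con.execute(arr[1].name("item"))                        # GET         -> 2
con.execute(arr[1:3].name("sliced"))                    # ARRAY_SLICE -> [2, 3]
con.execute((arr + ibis.literal([2, 1])).name("cat"))   # ARRAY_CAT   -> [1, 2, 3, 2, 1]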