Skip to content

Commit

Permalink
feat(snowflake): add more array operations
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Jan 18, 2023
1 parent e74328b commit 8d8bb70
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 14 deletions.
30 changes: 25 additions & 5 deletions ibis/backends/snowflake/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ def _literal(t, op):
return sa.func.timestamp_from_parts(*args)
elif dtype.is_date():
return sa.func.date_from_parts(value.year, value.month, value.day)
elif dtype.is_array():
return sa.func.array_construct(*value)
return _postgres_literal(t, op)


Expand Down Expand Up @@ -88,6 +90,22 @@ def _extract_url_query(t, op):
return sa.func.nullif(sa.func.as_varchar(r), "")


def _array_slice(t, op):
arg = t.translate(op.arg)

if (start := op.start) is not None:
start = t.translate(start)
else:
start = 0

if (stop := op.stop) is not None:
stop = t.translate(stop)
else:
stop = sa.func.array_size(arg)

return sa.func.array_slice(t.translate(op.arg), start, stop)


_SF_POS_INF = sa.cast(sa.literal("Inf"), sa.FLOAT)
_SF_NEG_INF = -_SF_POS_INF
_SF_NAN = sa.cast(sa.literal("NaN"), sa.FLOAT)
Expand Down Expand Up @@ -158,6 +176,13 @@ def _extract_url_query(t, op):
),
# snowflake typeof only accepts VARIANT
ops.TypeOf: unary(lambda arg: sa.func.typeof(sa.cast(arg, VARIANT))),
ops.ArrayIndex: fixed_arity(sa.func.get, 2),
ops.ArrayLength: fixed_arity(sa.func.array_size, 1),
ops.ArrayConcat: fixed_arity(sa.func.array_cat, 2),
ops.ArrayColumn: lambda t, op: sa.func.array_construct(
*map(t.translate, op.cols)
),
ops.ArraySlice: _array_slice,
}
)

Expand All @@ -169,12 +194,7 @@ def _extract_url_query(t, op):
ops.NTile,
ops.NthValue,
# ibis.expr.operations.array
ops.ArrayColumn,
ops.ArrayConcat,
ops.ArrayIndex,
ops.ArrayLength,
ops.ArrayRepeat,
ops.ArraySlice,
ops.Unnest,
# ibis.expr.operations.maps
ops.MapKeys,
Expand Down
15 changes: 6 additions & 9 deletions ibis/backends/tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
]


@pytest.mark.notimpl(["datafusion", "snowflake"])
@pytest.mark.notimpl(["datafusion"])
def test_array_column(backend, alltypes, df):
expr = ibis.array([alltypes['double_col'], alltypes['double_col']])
assert isinstance(expr, ir.ArrayColumn)
Expand All @@ -35,7 +35,6 @@ def test_array_column(backend, alltypes, df):
backend.assert_series_equal(result, expected, check_names=False)


@pytest.mark.notimpl(["snowflake"])
def test_array_scalar(con):
expr = ibis.array([1.0, 2.0, 3.0])
assert isinstance(expr, ir.ArrayScalar)
Expand All @@ -48,7 +47,7 @@ def test_array_scalar(con):
assert np.array_equal(result, expected)


@pytest.mark.notimpl(["snowflake", "polars", "datafusion"])
@pytest.mark.notimpl(["polars", "datafusion", "snowflake"])
def test_array_repeat(con):
expr = ibis.array([1.0, 2.0]) * 2

Expand All @@ -61,7 +60,7 @@ def test_array_repeat(con):


# Issues #2370
@pytest.mark.notimpl(["datafusion", "snowflake"])
@pytest.mark.notimpl(["datafusion"])
def test_array_concat(con):
left = ibis.literal([1, 2, 3])
right = ibis.literal([2, 1])
Expand All @@ -74,13 +73,12 @@ def test_array_concat(con):
assert np.array_equal(result, expected)


@pytest.mark.notimpl(["datafusion", "snowflake"])
@pytest.mark.notimpl(["datafusion"])
def test_array_length(con):
expr = ibis.literal([1, 2, 3]).length()
assert con.execute(expr.name("tmp")) == 3


@pytest.mark.notimpl(["snowflake"])
def test_list_literal(con):
arr = [1, 2, 3]
expr = ibis.literal(arr)
Expand All @@ -91,7 +89,6 @@ def test_list_literal(con):
assert np.array_equal(result, arr)


@pytest.mark.notimpl(["snowflake"])
def test_np_array_literal(con):
arr = np.array([1, 2, 3])
expr = ibis.literal(arr)
Expand All @@ -103,7 +100,7 @@ def test_np_array_literal(con):


@pytest.mark.parametrize("idx", range(3))
@pytest.mark.notimpl(["snowflake", "polars", "datafusion"])
@pytest.mark.notimpl(["polars", "datafusion"])
def test_array_index(con, idx):
arr = [1, 2, 3]
expr = ibis.literal(arr)
Expand Down Expand Up @@ -372,7 +369,7 @@ def test_unnest_default_name(con):
(-3, -1),
],
)
@pytest.mark.notimpl(["dask", "datafusion", "polars", "snowflake"])
@pytest.mark.notimpl(["dask", "datafusion", "polars"])
def test_array_slice(con, start, stop):
array_types = con.tables.array_types
expr = array_types.select(sliced=array_types.y[start:stop])
Expand Down

0 comments on commit 8d8bb70

Please sign in to comment.