Skip to content

Commit

Permalink
feat(snowflake): implement ops.ArrayRepeat
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud authored and kszucs committed Aug 5, 2023
1 parent 6f7e13d commit a93cbd6
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 6 deletions.
9 changes: 8 additions & 1 deletion ibis/backends/snowflake/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import sqlalchemy as sa
import sqlalchemy.types as sat
from snowflake.connector.constants import FIELD_ID_TO_NAME
from snowflake.sqlalchemy import ARRAY, OBJECT, URL
from snowflake.sqlalchemy import ARRAY, DOUBLE, OBJECT, URL
from sqlalchemy.ext.compiler import compiles

import ibis
Expand Down Expand Up @@ -96,6 +96,13 @@ class SnowflakeCompiler(AlchemyCompiler):
"returns": ARRAY,
"source": """return array.sort();""",
},
"ibis_udfs.public.array_repeat": {
# Integer inputs are not allowed because JavaScript only supports
# doubles
"inputs": {"value": ARRAY, "count": DOUBLE},
"returns": ARRAY,
"source": """return Array(count).fill(value).flat();""",
},
}


Expand Down
3 changes: 1 addition & 2 deletions ibis/backends/snowflake/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,7 @@ def _map_get(t, op):
ops.TableColumn: _table_column,
ops.Levenshtein: fixed_arity(sa.func.editdistance, 2),
ops.ArraySort: unary(sa.func.ibis_udfs.public.array_sort),
ops.ArrayRepeat: fixed_arity(sa.func.ibis_udfs.public.array_repeat, 2),
}
)

Expand All @@ -471,8 +472,6 @@ def _map_get(t, op):
ops.CumulativeAny,
ops.CumulativeOp,
ops.NTile,
# ibis.expr.operations.array
ops.ArrayRepeat,
# ibis.expr.operations.reductions
ops.MultiQuantile,
# ibis.expr.operations.strings
Expand Down
4 changes: 1 addition & 3 deletions ibis/backends/tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,7 @@ def test_array_scalar(con, backend):
assert con.execute(expr.typeof()) == ARRAY_BACKEND_TYPES[backend_name]


@pytest.mark.notimpl(
["polars", "datafusion", "snowflake"], raises=com.OperationNotDefinedError
)
@pytest.mark.notimpl(["polars", "datafusion"], raises=com.OperationNotDefinedError)
def test_array_repeat(con):
expr = ibis.array([1.0, 2.0]) * 2

Expand Down

0 comments on commit a93cbd6

Please sign in to comment.