Skip to content

Commit

Permalink
feat(snowflake): implement array map and array filter (#9178)
Browse files Browse the repository at this point in the history
Adds array map and filter implementations to the Snowflake backend.
Thanks again @krzysztof-kwitt for the idea!

The main limitation for the Snowflake versions of these APIs is that
your lambda cannot reference any columns outside the scope of the
lambda, which means the scope is limited to the lambda parameter and
constants.
  • Loading branch information
cpcloud authored May 13, 2024
1 parent 1ba4c32 commit 9b42751
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 36 deletions.
21 changes: 19 additions & 2 deletions ibis/backends/snowflake/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,6 @@ class SnowflakeCompiler(SQLGlotCompiler):
}

UNSUPPORTED_OPS = (
ops.ArrayMap,
ops.ArrayFilter,
ops.RowID,
ops.MultiQuantile,
ops.IntervalFromInteger,
Expand Down Expand Up @@ -639,3 +637,22 @@ def visit_Sample(
seed=None if seed is None else sge.convert(seed),
)
return sg.select(STAR).from_(sample)

def visit_ArrayMap(self, op, *, arg, param, body):
return self.f.transform(arg, sge.Lambda(this=body, expressions=[param]))

def visit_ArrayFilter(self, op, *, arg, param, body):
return self.f.filter(
arg,
sge.Lambda(
this=sg.and_(
body,
# necessary otherwise null values are treated as JSON nulls
# instead of SQL NULLs
self.cast(sg.to_identifier(param), op.dtype.value_type).is_(
sg.not_(NULL)
),
),
expressions=[param],
),
)
40 changes: 6 additions & 34 deletions ibis/backends/tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,14 +418,7 @@ def test_array_slice(backend, start, stop):

@builtin_array
@pytest.mark.notimpl(
[
"datafusion",
"flink",
"polars",
"snowflake",
"sqlite",
],
raises=com.OperationNotDefinedError,
["datafusion", "flink", "polars", "sqlite"], raises=com.OperationNotDefinedError
)
@pytest.mark.broken(
["risingwave"],
Expand Down Expand Up @@ -465,6 +458,7 @@ def test_array_slice(backend, start, stop):
functools.partial(lambda x, y: x + y, y=1),
ibis._ + 1,
],
ids=["lambda", "partial", "deferred"],
)
@pytest.mark.broken(
["risingwave"],
Expand All @@ -485,14 +479,7 @@ def test_array_map(con, input, output, func):

@builtin_array
@pytest.mark.notimpl(
[
"dask",
"datafusion",
"flink",
"pandas",
"polars",
"snowflake",
],
["dask", "datafusion", "flink", "pandas", "polars"],
raises=com.OperationNotDefinedError,
)
@pytest.mark.notimpl(
Expand Down Expand Up @@ -533,6 +520,7 @@ def test_array_map(con, input, output, func):
functools.partial(lambda x, y: x > y, y=1),
ibis._ > 1,
],
ids=["lambda", "partial", "deferred"],
)
def test_array_filter(con, input, output, predicate):
t = ibis.memtable(input, schema=ibis.schema(dict(a="!array<int8>")))
Expand Down Expand Up @@ -1138,14 +1126,7 @@ def test_unnest_empty_array(con):

@builtin_array
@pytest.mark.notimpl(
[
"datafusion",
"flink",
"polars",
"snowflake",
"dask",
"pandas",
],
["datafusion", "flink", "polars", "dask", "pandas"],
raises=com.OperationNotDefinedError,
)
@pytest.mark.notimpl(["sqlite"], raises=com.UnsupportedBackendType)
Expand All @@ -1166,16 +1147,7 @@ def test_array_map_with_conflicting_names(backend, con):

@builtin_array
@pytest.mark.notimpl(
[
"datafusion",
"flink",
"polars",
"snowflake",
"sqlite",
"dask",
"pandas",
"sqlite",
],
["datafusion", "flink", "polars", "sqlite", "dask", "pandas", "sqlite"],
raises=com.OperationNotDefinedError,
)
def test_complex_array_map(con):
Expand Down

0 comments on commit 9b42751

Please sign in to comment.