From 9b4275105ce7a8ad7cad798ea8a79639cea3076f Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Mon, 13 May 2024 10:13:46 -0400 Subject: [PATCH] feat(snowflake): implement array map and array filter (#9178) Adds array map and filter implementations to the Snowflake backend. Thanks again @krzysztof-kwitt for the idea! The main limitation for the Snowflake versions of these APIs is that your lambda cannot reference any columns outside the scope of the lambda, which means the scope is limited to the lambda parameter and constants. --- ibis/backends/snowflake/compiler.py | 21 +++++++++++++-- ibis/backends/tests/test_array.py | 40 +++++------------------------ 2 files changed, 25 insertions(+), 36 deletions(-) diff --git a/ibis/backends/snowflake/compiler.py b/ibis/backends/snowflake/compiler.py index ff80e20e5116..33d4ad325e2e 100644 --- a/ibis/backends/snowflake/compiler.py +++ b/ibis/backends/snowflake/compiler.py @@ -48,8 +48,6 @@ class SnowflakeCompiler(SQLGlotCompiler): } UNSUPPORTED_OPS = ( - ops.ArrayMap, - ops.ArrayFilter, ops.RowID, ops.MultiQuantile, ops.IntervalFromInteger, @@ -639,3 +637,22 @@ def visit_Sample( seed=None if seed is None else sge.convert(seed), ) return sg.select(STAR).from_(sample) + + def visit_ArrayMap(self, op, *, arg, param, body): + return self.f.transform(arg, sge.Lambda(this=body, expressions=[param])) + + def visit_ArrayFilter(self, op, *, arg, param, body): + return self.f.filter( + arg, + sge.Lambda( + this=sg.and_( + body, + # necessary otherwise null values are treated as JSON nulls + # instead of SQL NULLs + self.cast(sg.to_identifier(param), op.dtype.value_type).is_( + sg.not_(NULL) + ), + ), + expressions=[param], + ), + ) diff --git a/ibis/backends/tests/test_array.py b/ibis/backends/tests/test_array.py index 7340540f7269..d78b9bf326e1 100644 --- a/ibis/backends/tests/test_array.py +++ b/ibis/backends/tests/test_array.py @@ -418,14 +418,7 @@ def test_array_slice(backend, start, stop): @builtin_array @pytest.mark.notimpl( - [ - "datafusion", - "flink", - "polars", - "snowflake", - "sqlite", - ], - raises=com.OperationNotDefinedError, + ["datafusion", "flink", "polars", "sqlite"], raises=com.OperationNotDefinedError ) @pytest.mark.broken( ["risingwave"], @@ -465,6 +458,7 @@ def test_array_slice(backend, start, stop): functools.partial(lambda x, y: x + y, y=1), ibis._ + 1, ], + ids=["lambda", "partial", "deferred"], ) @pytest.mark.broken( ["risingwave"], @@ -485,14 +479,7 @@ def test_array_map(con, input, output, func): @builtin_array @pytest.mark.notimpl( - [ - "dask", - "datafusion", - "flink", - "pandas", - "polars", - "snowflake", - ], + ["dask", "datafusion", "flink", "pandas", "polars"], raises=com.OperationNotDefinedError, ) @pytest.mark.notimpl( @@ -533,6 +520,7 @@ def test_array_map(con, input, output, func): functools.partial(lambda x, y: x > y, y=1), ibis._ > 1, ], + ids=["lambda", "partial", "deferred"], ) def test_array_filter(con, input, output, predicate): t = ibis.memtable(input, schema=ibis.schema(dict(a="!array"))) @@ -1138,14 +1126,7 @@ def test_unnest_empty_array(con): @builtin_array @pytest.mark.notimpl( - [ - "datafusion", - "flink", - "polars", - "snowflake", - "dask", - "pandas", - ], + ["datafusion", "flink", "polars", "dask", "pandas"], raises=com.OperationNotDefinedError, ) @pytest.mark.notimpl(["sqlite"], raises=com.UnsupportedBackendType) @@ -1166,16 +1147,7 @@ def test_array_map_with_conflicting_names(backend, con): @builtin_array @pytest.mark.notimpl( - [ - "datafusion", - "flink", - "polars", - "snowflake", - "sqlite", - "dask", - "pandas", - "sqlite", - ], + ["datafusion", "flink", "polars", "sqlite", "dask", "pandas", "sqlite"], raises=com.OperationNotDefinedError, ) def test_complex_array_map(con):