feat(snowflake): implement array map and array filter
cpcloud committed May 13, 2024
1 parent 5488896 commit 6845631
Showing 2 changed files with 25 additions and 36 deletions.
ibis/backends/snowflake/compiler.py: 21 changes (19 additions, 2 deletions)
@@ -48,8 +48,6 @@ class SnowflakeCompiler(SQLGlotCompiler):
     }
 
     UNSUPPORTED_OPS = (
-        ops.ArrayMap,
-        ops.ArrayFilter,
         ops.RowID,
         ops.MultiQuantile,
         ops.IntervalFromInteger,
@@ -634,3 +632,22 @@ def visit_Sample(
             seed=None if seed is None else sge.convert(seed),
         )
         return sg.select(STAR).from_(sample)
+
+    def visit_ArrayMap(self, op, *, arg, param, body):
+        return self.f.transform(arg, sge.Lambda(this=body, expressions=[param]))
+
+    def visit_ArrayFilter(self, op, *, arg, param, body):
+        return self.f.filter(
+            arg,
+            sge.Lambda(
+                this=sg.and_(
+                    body,
+                    # necessary otherwise null values are treated as JSON nulls
+                    # instead of SQL NULLs
+                    self.cast(sg.to_identifier(param), op.dtype.value_type).is_(
+                        sg.not_(NULL)
+                    ),
+                ),
+                expressions=[param],
+            ),
+        )
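
A quick sketch of how these two new visit methods surface through the expression API. The table name, column type, and printed SQL below are illustrative assumptions, not output taken from this commit; the general shape is Snowflake's higher-order TRANSFORM/FILTER functions with a lambda, plus the extra IS NOT NULL cast that keeps VARIANT/JSON nulls from leaking through the filter predicate.

```python
# Illustrative sketch only: the table, dtype, and exact SQL text below are
# assumptions for demonstration, not part of this commit.
import ibis

t = ibis.table({"a": "array<int8>"}, name="t")

mapped = t.a.map(lambda x: x + 1)       # lowers to ops.ArrayMap -> visit_ArrayMap
filtered = t.a.filter(lambda x: x > 1)  # lowers to ops.ArrayFilter -> visit_ArrayFilter

# The generated Snowflake SQL should look roughly like:
#   TRANSFORM("t"."a", x -> x + 1)
#   FILTER("t"."a", x -> (x > 1) AND CAST(x AS INT) IS NOT NULL)
print(ibis.to_sql(mapped, dialect="snowflake"))
print(ibis.to_sql(filtered, dialect="snowflake"))
```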
ibis/backends/tests/test_array.py: 40 changes (6 additions, 34 deletions)
@@ -418,14 +418,7 @@ def test_array_slice(backend, start, stop):
 
 @builtin_array
 @pytest.mark.notimpl(
-    [
-        "datafusion",
-        "flink",
-        "polars",
-        "snowflake",
-        "sqlite",
-    ],
-    raises=com.OperationNotDefinedError,
+    ["datafusion", "flink", "polars", "sqlite"], raises=com.OperationNotDefinedError
 )
 @pytest.mark.broken(
     ["risingwave"],
@@ -465,6 +458,7 @@ def test_array_slice(backend, start, stop):
         functools.partial(lambda x, y: x + y, y=1),
         ibis._ + 1,
     ],
+    ids=["lambda", "partial", "deferred"],
 )
 @pytest.mark.broken(
     ["risingwave"],
@@ -485,14 +479,7 @@ def test_array_map(con, input, output, func):
 
 @builtin_array
 @pytest.mark.notimpl(
-    [
-        "dask",
-        "datafusion",
-        "flink",
-        "pandas",
-        "polars",
-        "snowflake",
-    ],
+    ["dask", "datafusion", "flink", "pandas", "polars"],
     raises=com.OperationNotDefinedError,
 )
 @pytest.mark.notimpl(
@@ -533,6 +520,7 @@ def test_array_map(con, input, output, func):
         functools.partial(lambda x, y: x > y, y=1),
         ibis._ > 1,
     ],
+    ids=["lambda", "partial", "deferred"],
 )
 def test_array_filter(con, input, output, predicate):
     t = ibis.memtable(input, schema=ibis.schema(dict(a="!array<int8>")))
@@ -1138,14 +1126,7 @@ def test_unnest_empty_array(con):
 
 @builtin_array
 @pytest.mark.notimpl(
-    [
-        "datafusion",
-        "flink",
-        "polars",
-        "snowflake",
-        "dask",
-        "pandas",
-    ],
+    ["datafusion", "flink", "polars", "dask", "pandas"],
     raises=com.OperationNotDefinedError,
 )
 @pytest.mark.notimpl(["sqlite"], raises=com.UnsupportedBackendType)
@@ -1166,16 +1147,7 @@ def test_array_map_with_conflicting_names(backend, con):
 
 @builtin_array
 @pytest.mark.notimpl(
-    [
-        "datafusion",
-        "flink",
-        "polars",
-        "snowflake",
-        "sqlite",
-        "dask",
-        "pandas",
-        "sqlite",
-    ],
+    ["datafusion", "flink", "polars", "sqlite", "dask", "pandas", "sqlite"],
     raises=com.OperationNotDefinedError,
 )
 def test_complex_array_map(con):
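For reference, the new ids added to the parametrizations above label three interchangeable ways of spelling the mapping or filtering function. A minimal sketch follows; the memtable contents are illustrative and mirror, but do not copy, the test fixtures.

```python
# Minimal sketch of the "lambda" / "partial" / "deferred" styles the tests
# parametrize over; the memtable contents are illustrative, not the fixtures.
import functools

import ibis

t = ibis.memtable({"a": [[1, 2, 3]]}, schema=ibis.schema(dict(a="!array<int8>")))

exprs = [
    t.a.map(lambda x: x + 1),                             # "lambda"
    t.a.map(functools.partial(lambda x, y: x + y, y=1)),  # "partial"
    t.a.map(ibis._ + 1),                                  # "deferred"
]
# All three produce equivalent ops.ArrayMap expressions, so on Snowflake each
# now compiles through the new visit_ArrayMap method instead of raising
# OperationNotDefinedError.
```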
