Skip to content

Commit

Permalink
feat(pyspark): add ArrayFilter operation
Browse files Browse the repository at this point in the history
  • Loading branch information
tokoko authored and cpcloud committed Mar 26, 2023
1 parent 315b5e7 commit 2b1301e
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 1 deletion.
14 changes: 14 additions & 0 deletions ibis/backends/pyspark/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1655,6 +1655,20 @@ def compile_array_collect(t, op, **kwargs):
return F.collect_list(src_column)


@compiles(ops.Argument)
def compile_argument(t, op, arg_columns, **kwargs):
    """Resolve a lambda argument to the PySpark column bound under its name.

    ``arg_columns`` is the name-to-column mapping supplied by the enclosing
    higher-order-function compiler (e.g. ``compile_array_filter``).
    """
    bound_column = arg_columns[op.name]
    return bound_column


@compiles(ops.ArrayFilter)
def compile_array_filter(t, op, **kwargs):
    """Compile ``ArrayFilter`` to PySpark's higher-order ``F.filter``.

    The filter predicate body (``op.result``) is re-translated for each
    element, with the lambda's parameter bound to the current element column
    via ``arg_columns`` (consumed by ``compile_argument``).
    """
    array_column = t.translate(op.arg, **kwargs)

    def translate_predicate(element):
        # Bind the lambda parameter to the element PySpark hands us,
        # then translate the predicate expression against that binding.
        return t.translate(
            op.result, arg_columns={op.parameter: element}, **kwargs
        )

    return F.filter(array_column, translate_predicate)


# --------------------------- Null Operations -----------------------------


Expand Down
14 changes: 14 additions & 0 deletions ibis/backends/pyspark/tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,3 +166,17 @@ def test_array_collect(client):
.rename(columns={'array_int': 'collected'})
)
tm.assert_frame_equal(result, expected)


def test_array_filter(client):
    """``ArrayValue.filter`` should drop elements matching the predicate.

    Compares the ibis result against a pandas-side reimplementation of the
    same element-wise filter over the raw Spark data.
    """
    tbl = client.table('array_table')
    filtered = tbl.array_int.filter(lambda item: item != 3).name('array_int')
    result = tbl.select(filtered).execute()

    # Build the expected frame by filtering each array in plain Python.
    raw = tbl.compile().toPandas()
    raw['array_int'] = raw['array_int'].apply(
        lambda ar: [item for item in ar if item != 3]
    )
    expected = raw[['array_int']]
    tm.assert_frame_equal(result, expected)
1 change: 0 additions & 1 deletion ibis/backends/tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,6 @@ def test_array_map(backend, con):
"mssql",
"polars",
"postgres",
"pyspark",
"snowflake",
],
raises=com.OperationNotDefinedError,
Expand Down

0 comments on commit 2b1301e

Please sign in to comment.