Skip to content

Commit

Permalink
fix(pyspark): ensure that the output of zip matches the expected ibis…
Browse files Browse the repository at this point in the history
… schema (#9052)

Fix PySpark zip implementation to ensure that its output matches the
schema expected by Ibis. Fixes #9049.

---------

Co-authored-by: Gil Forsyth <gforsyth@users.noreply.github.com>
  • Loading branch information
cpcloud and gforsyth authored Apr 25, 2024
1 parent 92ba1c2 commit be9d5da
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 1 deletion.
2 changes: 1 addition & 1 deletion ibis/backends/pyspark/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ def visit_MapGet(self, op, *, arg, key, default):
return self.if_(self.f.map_contains_key(arg, key), arg[key], default)

def visit_ArrayZip(self, op, *, arg):
    """Compile an array zip and cast the result to the operation's dtype.

    ``arg`` is the sequence of compiled array expressions; it is unpacked
    into a single ``arrays_zip`` call. The result is then cast to
    ``op.dtype`` so that the output matches the schema Ibis expects for
    the zip operation.
    """
    zipped = self.f.arrays_zip(*arg)
    # Without the cast, the zip output does not match the expected Ibis schema.
    return self.cast(zipped, op.dtype)

def visit_ArrayMap(self, op, *, arg, body, param):
param = sge.Identifier(this=param)
Expand Down
20 changes: 20 additions & 0 deletions ibis/backends/tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1401,3 +1401,23 @@ def test_array_literal_with_exprs(con, input, expected):
assert expr.op().shape == ds.scalar
result = list(con.execute(expr))
assert result == expected


@pytest.mark.notimpl(
    ["datafusion", "postgres", "pandas", "polars", "risingwave", "dask", "flink"],
    raises=com.OperationNotDefinedError,
)
@pytest.mark.broken(
    ["trino"],
    raises=TrinoUserError,
    reason="sqlglot generates code that assumes there's only at most two fields to unpack from a struct",
)
def test_zip_unnest_lift(con):
    """Zip two array columns, unnest the zipped column, and lift its fields.

    Builds a one-row in-memory table with two integer arrays, zips them
    into a single column, unnests that column, and lifts the resulting
    struct into top-level columns. The executed result must equal a frame
    with columns ``f1`` and ``f2`` holding the elements of the first and
    second input arrays respectively.
    """
    source = pd.DataFrame({"array1": [[1, 2, 3]], "array2": [[4, 5, 6]]})
    table = ibis.memtable(source)

    # Each step adds one derived column to the previous table expression.
    with_zipped = table.mutate(zipped=table.array1.zip(table.array2))
    with_unnest = with_zipped.mutate(unnest=with_zipped.zipped.unnest())

    actual = con.execute(with_unnest.unnest.lift())
    wanted = pd.DataFrame(dict(f1=[1, 2, 3], f2=[4, 5, 6]))
    tm.assert_frame_equal(actual, wanted)

0 comments on commit be9d5da

Please sign in to comment.