Skip to content

Commit

Permalink
fix(polars): use flatten API for ArrayFlatten implementation to avoid large string upcast (#9997)
Browse files Browse the repository at this point in the history

Closes #9995.
  • Loading branch information
cpcloud authored Sep 3, 2024
1 parent e0f54c9 commit 7a6af8d
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 13 deletions.
9 changes: 8 additions & 1 deletion ibis/backends/polars/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1008,7 +1008,14 @@ def array_collect(op, in_group_by=False, **kw):

@translate.register(ops.ArrayFlatten)
def array_flatten(op, **kw):
    """Translate ``ops.ArrayFlatten`` into a polars expression.

    Null inputs map to null, empty arrays map to an empty array, and every
    other value goes through ``Expr.flatten``.  The flatten API is used
    instead of ``pl.concat_list`` to avoid a large string upcast.
    """
    result = translate(op.arg, **kw)
    return (
        # Null and empty arrays get explicit branches so that only non-empty
        # arrays reach ``flatten``.
        pl.when(result.is_null())
        .then(None)
        .when(result.list.len() == 0)
        .then([])
        # NOTE(review): confirm how ``flatten`` behaves when an outer array
        # holds more than one inner array -- TODO verify against polars docs.
        .otherwise(result.flatten())
    )


_date_methods = {
Expand Down
15 changes: 15 additions & 0 deletions ibis/backends/polars/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,13 @@

import pytest

import ibis
from ibis.backends.tests.errors import PolarsSQLInterfaceError
from ibis.util import gen_name

pd = pytest.importorskip("pandas")
tm = pytest.importorskip("pandas.testing")


def test_cannot_run_sql_after_drop(con):
t = con.table("functional_alltypes")
Expand All @@ -22,3 +26,14 @@ def test_cannot_run_sql_after_drop(con):
con.drop_table(name)
with pytest.raises(PolarsSQLInterfaceError):
con.sql(sql)


def test_array_flatten(con):
    """Check that flattening a nested string array drops one nesting level.

    Every row of ``happy`` wraps exactly one inner array, so the expected
    flattened value for a row is that inner array.
    """
    nested = [[["abc"]], [["bcd"]], [["def"]]]
    ids = range(3)
    table = ibis.memtable({"id": ids, "happy": nested})

    # Order by id so rows line up positionally with the expected frame.
    flattened = table.select("id", flat=table.happy.flatten()).order_by("id")
    actual = con.to_pyarrow(flattened).to_pandas()

    wanted = pd.DataFrame({"id": ids, "flat": [outer[0] for outer in nested]})
    tm.assert_frame_equal(actual, wanted)
20 changes: 8 additions & 12 deletions ibis/backends/tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1037,25 +1037,21 @@ def flatten_data():
reason="Arrays are never nullable",
raises=AssertionError,
),
pytest.mark.notimpl(
["polars"],
raises=TypeError,
reason="comparison of nested arrays doesn't work in pandas testing module",
),
],
),
],
)
@pytest.mark.notimpl(["flink"], raises=com.OperationNotDefinedError)
def test_array_flatten(backend, flatten_data, column, expected):
    """Check ``flatten`` on the parametrized ``column`` of ``flatten_data``.

    An ``id`` column is carried alongside the array column and used for
    ordering, so the result is compared as a two-column frame.
    """
    data = flatten_data[column]
    ids = range(len(data["data"]))
    t = ibis.memtable(
        {column: data["data"], "id": ids}, schema={column: data["type"], "id": "int64"}
    )
    expr = t.select("id", flat=t[column].flatten()).order_by("id")
    result = backend.connection.to_pandas(expr)
    # ``expected`` is renamed and framed to match the selected column layout.
    tm.assert_frame_equal(
        result, expected.rename("flat").to_frame().assign(id=ids)[["id", "flat"]]
    )


Expand Down

0 comments on commit 7a6af8d

Please sign in to comment.