Skip to content

Commit

Permalink
fix(polars): ensure that pivot_longer works with more than one column
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Jul 21, 2023
1 parent cb1956f commit 822c912
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 20 deletions.
33 changes: 18 additions & 15 deletions ibis/backends/polars/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,11 +188,18 @@ def selection(op, **kw):
lf = lf.filter(predicate)

selections = []
unnests = []
for arg in op.selections:
if isinstance(arg, ops.TableNode):
for name in arg.schema.names:
column = ops.TableColumn(table=arg, name=name)
selections.append(translate(column, **kw))
elif (
isinstance(arg, ops.Alias) and isinstance(unnest := arg.arg, ops.Unnest)
) or isinstance(unnest := arg, ops.Unnest):
name = arg.name
unnests.append(name)
selections.append(translate(unnest.arg, **kw).alias(name))
elif isinstance(arg, ops.Value):
selections.append(translate(arg, **kw))
else:
Expand All @@ -204,6 +211,9 @@ def selection(op, **kw):
if selections:
lf = lf.select(selections)

if unnests:
lf = lf.explode(*unnests)

if op.sort_keys:
by = [key.name for key in op.sort_keys]
descending = [key.descending for key in op.sort_keys]
Expand Down Expand Up @@ -698,10 +708,12 @@ def struct_column(op, **kw):
def reduction(op, **kw):
arg = translate(op.arg, **kw)
agg = _reductions[type(op)]
filt = arg.is_not_null()
if (where := op.where) is not None:
arg = arg.filter(translate(where, **kw))
filt &= translate(where, **kw)
arg = arg.filter(filt)
method = getattr(arg, agg)
return method()
return method().cast(dtype_to_polars(op.output_dtype))


@translate.register(ops.Mode)
Expand Down Expand Up @@ -739,10 +751,10 @@ def distinct(op, **kw):
def count_star(op, **kw):
if (where := op.where) is not None:
condition = translate(where, **kw)
# TODO: clean up the casts and use schema.apply_to in the backend's
# execute method
return condition.filter(condition).count().cast(pl.Int64)
return pl.count().cast(pl.Int64)
result = condition.sum()
else:
result = pl.count()
return result.cast(dtype_to_polars(op.output_dtype))


@translate.register(ops.TimestampNow)
Expand Down Expand Up @@ -891,15 +903,6 @@ def array_collect(op, **kw):
return arg.list() # pragma: no cover


@translate.register(ops.Unnest)
def unnest(op, **kw):
arg = translate(op.arg, **kw)
try:
return arg.arr.explode()
except AttributeError:
return arg.explode()


_date_methods = {
ops.ExtractDay: "day",
ops.ExtractMonth: "month",
Expand Down
11 changes: 7 additions & 4 deletions ibis/backends/tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,6 @@ def test_unnest_simple(backend):


@unnest
@pytest.mark.notimpl("polars", raises=PolarsComputeError, reason="Series shape: (6,)")
@pytest.mark.notimpl("dask", raises=com.OperationNotDefinedError)
def test_unnest_complex(backend):
array_types = backend.array_types
Expand Down Expand Up @@ -370,7 +369,9 @@ def test_unnest_complex(backend):
reason="clickhouse throws away nulls in groupArray",
raises=AssertionError,
)
@pytest.mark.notimpl("polars", raises=PolarsComputeError, reason="Series shape: (6,)")
@pytest.mark.notyet(
"polars", raises=AssertionError, reason="polars implode returns the wrong shape"
)
@pytest.mark.notimpl(["dask"], raises=ValueError)
def test_unnest_idempotent(backend):
array_types = backend.array_types
Expand All @@ -391,8 +392,10 @@ def test_unnest_idempotent(backend):


@unnest
@pytest.mark.notimpl("polars", raises=PolarsComputeError, reason="Series shape: (6,)")
@pytest.mark.notimpl(["dask"], raises=ValueError)
@pytest.mark.notimpl(
"polars", raises=TypeError, reason="polars implode returns the wrong shape"
)
@pytest.mark.notimpl("dask", raises=ValueError)
def test_unnest_no_nulls(backend):
array_types = backend.array_types
df = array_types.execute()
Expand Down
1 change: 0 additions & 1 deletion ibis/backends/tests/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -970,7 +970,6 @@ def query(t, group_cols):


@pytest.mark.notimpl(["dask", "pandas", "oracle"], raises=com.OperationNotDefinedError)
@pytest.mark.notyet(["polars"], reason="polars doesn't expand > 1 explode")
@pytest.mark.notimpl(["druid"], raises=AssertionError)
@pytest.mark.notyet(
["bigquery"],
Expand Down

0 comments on commit 822c912

Please sign in to comment.