Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Fix output type for list.eval in certain cases #18570

Merged
merged 3 commits into from
Sep 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/polars-lazy/src/dsl/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ fn run_on_group_by_engine(
let state = ExecutionState::new();
let mut ac = phys_expr.evaluate_on_groups(&df_context, &groups, &state)?;
let out = match ac.agg_state() {
AggState::AggregatedScalar(_) | AggState::Literal(_) => {
AggState::AggregatedScalar(_) => {
let out = ac.aggregated();
out.as_list().into_series()
},
Expand Down
132 changes: 132 additions & 0 deletions py-polars/tests/unit/operations/namespaces/list/test_eval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
from __future__ import annotations

from typing import Any

import pytest

import polars as pl
from polars.exceptions import (
StructFieldNotFoundError,
)
from polars.testing import assert_frame_equal, assert_series_equal


def test_list_eval_dtype_inference() -> None:
grades = pl.DataFrame(
{
"student": ["bas", "laura", "tim", "jenny"],
"arithmetic": [10, 5, 6, 8],
"biology": [4, 6, 2, 7],
"geography": [8, 4, 9, 7],
}
)

rank_pct = pl.col("").rank(descending=True) / pl.col("").count().cast(pl.UInt16)

# the .list.first() would fail if .list.eval did not correctly infer the output type
assert grades.with_columns(
pl.concat_list(pl.all().exclude("student")).alias("all_grades")
).select(
pl.col("all_grades")
.list.eval(rank_pct, parallel=True)
.alias("grades_rank")
.list.first()
).to_series().to_list() == [
0.3333333333333333,
0.6666666666666666,
0.6666666666666666,
0.3333333333333333,
]


def test_list_eval_categorical() -> None:
df = pl.DataFrame({"test": [["a", None]]}, schema={"test": pl.List(pl.Categorical)})
df = df.select(
pl.col("test").list.eval(pl.element().filter(pl.element().is_not_null()))
)
assert_series_equal(
df.get_column("test"), pl.Series("test", [["a"]], dtype=pl.List(pl.Categorical))
)


def test_list_eval_type_coercion() -> None:
last_non_null_value = pl.element().fill_null(3).last()
df = pl.DataFrame({"array_cols": [[1, None]]})

assert df.select(
pl.col("array_cols")
.list.eval(last_non_null_value, parallel=False)
.alias("col_last")
).to_dict(as_series=False) == {"col_last": [[3]]}


def test_list_eval_all_null() -> None:
df = pl.DataFrame({"foo": [1, 2, 3], "bar": [None, None, None]}).with_columns(
pl.col("bar").cast(pl.List(pl.String))
)

assert df.select(pl.col("bar").list.eval(pl.element())).to_dict(
as_series=False
) == {"bar": [None, None, None]}


def test_empty_eval_dtype_5546() -> None:
# https://github.com/pola-rs/polars/issues/5546
df = pl.DataFrame([{"a": [{"name": 1}, {"name": 2}]}])

dtype = df.dtypes[0]

assert (
df.limit(0).with_columns(
pl.col("a")
.list.eval(pl.element().filter(pl.first().struct.field("name") == 1))
.alias("a_filtered")
)
).dtypes == [dtype, dtype]


def test_list_eval_gather_every_13410() -> None:
df = pl.DataFrame({"a": [[1, 2, 3], [4, 5, 6]]})
out = df.with_columns(result=pl.col("a").list.eval(pl.element().gather_every(2)))
expected = pl.DataFrame({"a": [[1, 2, 3], [4, 5, 6]], "result": [[1, 3], [4, 6]]})
assert_frame_equal(out, expected)


def test_list_eval_err_raise_15653() -> None:
df = pl.DataFrame({"foo": [[]]})
with pytest.raises(StructFieldNotFoundError):
df.with_columns(bar=pl.col("foo").list.eval(pl.element().struct.field("baz")))


def test_list_eval_type_cast_11188() -> None:
df = pl.DataFrame(
[
{"a": None},
],
schema={"a": pl.List(pl.Int64)},
)
assert df.select(
pl.col("a").list.eval(pl.element().cast(pl.String)).alias("a_str")
).schema == {"a_str": pl.List(pl.String)}


@pytest.mark.parametrize(
"data",
[
{"a": [["0"], ["1"]]},
{"a": [["0", "1"], ["2", "3"]]},
{"a": [["0", "1"]]},
{"a": [["0"]]},
],
)
@pytest.mark.parametrize(
"expr",
[
pl.lit(""),
pl.format("test: {}", pl.element()),
],
)
def test_list_eval_list_output_18510(data: dict[str, Any], expr: pl.Expr) -> None:
df = pl.DataFrame(data)
result = df.select(pl.col("a").list.eval(pl.lit("")))
assert result.to_series().dtype == pl.List(pl.String)
106 changes: 1 addition & 105 deletions py-polars/tests/unit/operations/namespaces/list/test_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,7 @@
import pytest

import polars as pl
from polars.exceptions import (
ComputeError,
OutOfBoundsError,
SchemaError,
StructFieldNotFoundError,
)
from polars.exceptions import ComputeError, OutOfBoundsError, SchemaError
from polars.testing import assert_frame_equal, assert_series_equal


Expand Down Expand Up @@ -342,44 +337,6 @@ def test_slice() -> None:
assert s.list.slice(-5, 2).to_list() == [[1], []]


def test_list_eval_dtype_inference() -> None:
grades = pl.DataFrame(
{
"student": ["bas", "laura", "tim", "jenny"],
"arithmetic": [10, 5, 6, 8],
"biology": [4, 6, 2, 7],
"geography": [8, 4, 9, 7],
}
)

rank_pct = pl.col("").rank(descending=True) / pl.col("").count().cast(pl.UInt16)

# the .list.first() would fail if .list.eval did not correctly infer the output type
assert grades.with_columns(
pl.concat_list(pl.all().exclude("student")).alias("all_grades")
).select(
pl.col("all_grades")
.list.eval(rank_pct, parallel=True)
.alias("grades_rank")
.list.first()
).to_series().to_list() == [
0.3333333333333333,
0.6666666666666666,
0.6666666666666666,
0.3333333333333333,
]


def test_list_eval_categorical() -> None:
df = pl.DataFrame({"test": [["a", None]]}, schema={"test": pl.List(pl.Categorical)})
df = df.select(
pl.col("test").list.eval(pl.element().filter(pl.element().is_not_null()))
)
assert_series_equal(
df.get_column("test"), pl.Series("test", [["a"]], dtype=pl.List(pl.Categorical))
)


def test_list_ternary_concat() -> None:
df = pl.DataFrame(
{
Expand Down Expand Up @@ -423,17 +380,6 @@ def test_arr_contains_categorical() -> None:
assert result.to_dict(as_series=False) == expected


def test_list_eval_type_coercion() -> None:
last_non_null_value = pl.element().fill_null(3).last()
df = pl.DataFrame({"array_cols": [[1, None]]})

assert df.select(
pl.col("array_cols")
.list.eval(last_non_null_value, parallel=False)
.alias("col_last")
).to_dict(as_series=False) == {"col_last": [[3]]}


def test_list_slice() -> None:
df = pl.DataFrame(
{
Expand Down Expand Up @@ -476,21 +422,6 @@ def test_list_sliced_get_5186() -> None:
assert_frame_equal(out1, out2)


def test_empty_eval_dtype_5546() -> None:
# https://github.com/pola-rs/polars/issues/5546
df = pl.DataFrame([{"a": [{"name": 1}, {"name": 2}]}])

dtype = df.dtypes[0]

assert (
df.limit(0).with_columns(
pl.col("a")
.list.eval(pl.element().filter(pl.first().struct.field("name") == 1))
.alias("a_filtered")
)
).dtypes == [dtype, dtype]


def test_list_amortized_apply_explode_5812() -> None:
s = pl.Series([None, [1, 3], [0, -3], [1, 2, 2]])
assert s.list.sum().to_list() == [None, 4, -3, 5]
Expand Down Expand Up @@ -548,16 +479,6 @@ def test_list_gather() -> None:
]


def test_list_eval_all_null() -> None:
df = pl.DataFrame({"foo": [1, 2, 3], "bar": [None, None, None]}).with_columns(
pl.col("bar").cast(pl.List(pl.String))
)

assert df.select(pl.col("bar").list.eval(pl.element())).to_dict(
as_series=False
) == {"bar": [None, None, None]}


def test_list_function_group_awareness() -> None:
df = pl.DataFrame(
{
Expand Down Expand Up @@ -825,13 +746,6 @@ def test_list_get_logical_type() -> None:
assert_series_equal(out, expected)


def test_list_eval_gater_every_13410() -> None:
df = pl.DataFrame({"a": [[1, 2, 3], [4, 5, 6]]})
out = df.with_columns(result=pl.col("a").list.eval(pl.element().gather_every(2)))
expected = pl.DataFrame({"a": [[1, 2, 3], [4, 5, 6]], "result": [[1, 3], [4, 6]]})
assert_frame_equal(out, expected)


def test_list_gather_every() -> None:
df = pl.DataFrame(
{
Expand Down Expand Up @@ -896,24 +810,6 @@ def test_list_get_with_null() -> None:
assert_frame_equal(out, expected)


def test_list_eval_err_raise_15653() -> None:
df = pl.DataFrame({"foo": [[]]})
with pytest.raises(StructFieldNotFoundError):
df.with_columns(bar=pl.col("foo").list.eval(pl.element().struct.field("baz")))


def test_list_sum_bool_schema() -> None:
q = pl.LazyFrame({"x": [[True, True, False]]})
assert q.select(pl.col("x").list.sum()).collect_schema()["x"] == pl.UInt32


def test_list_eval_type_cast_11188() -> None:
df = pl.DataFrame(
[
{"a": None},
],
schema={"a": pl.List(pl.Int64)},
)
assert df.select(
pl.col("a").list.eval(pl.element().cast(pl.String)).alias("a_str")
).schema == {"a_str": pl.List(pl.String)}