Skip to content

Commit

Permalink
fix: gather.get schema (#13679)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Jan 12, 2024
1 parent 31306e5 commit 6f44725
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 1 deletion.
13 changes: 12 additions & 1 deletion crates/polars-plan/src/logical_plan/aexpr/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,18 @@ impl AExpr {
Ok(field)
},
Sort { expr, .. } => arena.get(*expr).to_field(schema, ctxt, arena),
Gather { expr, .. } => arena.get(*expr).to_field(schema, ctxt, arena),
Gather {
expr,
returns_scalar,
..
} => {
let ctxt = if *returns_scalar {
Context::Default
} else {
ctxt
};
arena.get(*expr).to_field(schema, ctxt, arena)
},
SortBy { expr, .. } => arena.get(*expr).to_field(schema, ctxt, arena),
Filter { input, .. } => arena.get(*input).to_field(schema, ctxt, arena),
Agg(agg) => {
Expand Down
23 changes: 23 additions & 0 deletions py-polars/tests/unit/operations/test_gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,26 @@ def test_negative_index() -> None:
assert df.group_by(pl.col("a") % 2).agg(b=pl.col("a").gather([0, -1])).sort(
"a"
).to_dict(as_series=False) == {"a": [0, 1], "b": [[2, 6], [1, 5]]}


def test_gather_agg_schema() -> None:
df = pl.DataFrame(
{
"group": [
"one",
"one",
"one",
"two",
"two",
"two",
],
"value": [1, 98, 2, 3, 99, 4],
}
)
assert (
df.lazy()
.group_by("group", maintain_order=True)
.agg(pl.col("value").get(1))
.schema["value"]
== pl.Int64
)

0 comments on commit 6f44725

Please sign in to comment.