Skip to content

Commit

Permalink
feat(rust, python): add sort for struct dtype (#7021)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Feb 19, 2023
1 parent e8ae356 commit 7a96006
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 8 deletions.
20 changes: 20 additions & 0 deletions polars/polars-core/src/series/implementations/struct_.rs
Original file line number Diff line number Diff line change
Expand Up @@ -362,4 +362,24 @@ impl SeriesTrait for SeriesWrap<StructChunked> {
/// Exposes the wrapped value (`self.0`) as a type-erased `&dyn Any`.
fn as_any(&self) -> &dyn Any {
    let inner = &self.0;
    inner
}

/// Sorts the struct series by its fields.
///
/// The struct is unnested into a `DataFrame`, that frame is sorted on all of
/// its own columns (every column receives the same `options.descending`
/// direction), and the sorted columns are packed back into a struct series
/// carrying the original name.
///
/// # Panics
///
/// Panics if the underlying `sort_impl` call returns an error.
fn sort_with(&self, options: SortOptions) -> Series {
    let frame = self.0.clone().unnest();

    // One direction flag per column, all equal to the requested direction.
    let directions = vec![options.descending; frame.width()];
    // Every column of the unnested frame acts as a sort key.
    let sort_keys = frame.columns.clone();

    let sorted = frame
        .sort_impl(
            sort_keys,
            directions,
            options.nulls_last,
            None,
            options.multithreaded,
        )
        .unwrap();

    StructChunked::new_unchecked(self.name(), &sorted.columns).into_series()
}
}
9 changes: 9 additions & 0 deletions py-polars/tests/unit/datatypes/test_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -803,3 +803,12 @@ def test_struct_concat_self_no_rechunk() -> None:
out = pl.concat([df, df], rechunk=False)
assert out.dtypes == [pl.Struct([pl.Field("a", pl.Int64)])]
assert out.to_dict(False) == {"A": [{"a": 1}, {"a": 1}]}


def test_sort_structs() -> None:
    """Sorting a struct of (sex, age) returns rows ordered by sex, then age."""
    frame = pl.DataFrame({"sex": ["male", "female", "female"], "age": [22, 38, 26]})
    sorted_struct = frame.select(pl.struct(["sex", "age"]).sort())

    expected = {
        "sex": ["female", "female", "male"],
        "age": [26, 38, 22],
    }
    # The struct column is named after its first field, hence unnest("sex").
    assert sorted_struct.unnest("sex").to_dict(False) == expected
8 changes: 0 additions & 8 deletions py-polars/tests/unit/test_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,6 @@ def test_string_numeric_comp_err() -> None:
pl.DataFrame({"a": [1.1, 21, 31, 21, 51, 61, 71, 81]}).select(pl.col("a") < "9")


def test_panic_exception() -> None:
    """Sorting an eager struct Series must raise ``pl.PanicException``."""
    expected_message = (
        r"""this operation is not implemented/valid for this dtype: .*"""
    )
    with pytest.raises(pl.PanicException, match=expected_message):
        struct_series = pl.struct(pl.Series("a", [1, 2, 3]), eager=True)
        struct_series.sort()


@typing.no_type_check
def test_join_lazy_on_df() -> None:
df_left = pl.DataFrame(
Expand Down

0 comments on commit 7a96006

Please sign in to comment.