From e3a8484bcf2dc53c46f866408e750ffc564b1576 Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Wed, 16 Nov 2022 09:48:31 +0100 Subject: [PATCH] fix(rust, python): fix sort_by expression if groups already aggregated --- .../src/physical_plan/expressions/sortby.rs | 230 ++++++++++-------- py-polars/polars/testing/_private.py | 29 +++ py-polars/tests/unit/test_sort.py | 19 ++ 3 files changed, 183 insertions(+), 95 deletions(-) diff --git a/polars/polars-lazy/src/physical_plan/expressions/sortby.rs b/polars/polars-lazy/src/physical_plan/expressions/sortby.rs index 15c70b25036d5..1438f42b08cee 100644 --- a/polars/polars-lazy/src/physical_plan/expressions/sortby.rs +++ b/polars/polars-lazy/src/physical_plan/expressions/sortby.rs @@ -97,113 +97,153 @@ impl PhysicalExpr for SortByExpr { ) -> PolarsResult> { let mut ac_in = self.input.evaluate_on_groups(df, groups, state)?; - let reverse = prepare_reverse(&self.reverse, self.by.len()); - - let (groups, ordered_by_group_operation) = if self.by.len() == 1 { + // the groups of the lhs of the expressions do not match the series values + // we must take the slower path. + if !matches!(ac_in.update_groups, UpdateGroups::No) { + if self.by.len() > 1 { + let msg = "This expression is not yet supported for more than one sort column. 
\ + Consider opening a feature request."; + return Err(expression_err!(msg, self.expr, ComputeError)); + } let mut ac_sort_by = self.by[0].evaluate_on_groups(df, groups, state)?; - let sort_by_s = ac_sort_by.flat_naive().into_owned(); + let sort_by = ac_sort_by.aggregated(); + let mut sort_by = sort_by.list().unwrap().clone(); + let s = ac_in.aggregated(); + let mut s = s.list().unwrap().clone(); - let ordered_by_group_operation = matches!( - ac_sort_by.update_groups, - UpdateGroups::WithSeriesLen | UpdateGroups::WithGroupsLen - ); - let groups = ac_sort_by.groups(); - - let groups = groups - .par_iter() - .map(|indicator| { - let new_idx = match indicator { - GroupsIndicator::Idx((_, idx)) => { - // Safety: - // Group tuples are always in bounds - let group = unsafe { - sort_by_s.take_iter_unchecked(&mut idx.iter().map(|i| *i as usize)) - }; - - let sorted_idx = group.argsort(SortOptions { - descending: reverse[0], + let descending = self.reverse[0]; + let mut ca: ListChunked = s + .par_iter_indexed() + .zip(sort_by.par_iter_indexed()) + .map(|(opt_s, s_sort_by)| match (opt_s, s_sort_by) { + (Some(s), Some(s_sort_by)) => { + if s.len() != s_sort_by.len() { + None + } else { + let idx = s_sort_by.argsort(SortOptions { + descending, ..Default::default() }); - map_sorted_indices_to_group_idx(&sorted_idx, idx) + Some(unsafe { s.take_unchecked(&idx).unwrap() }) } - GroupsIndicator::Slice([first, len]) => { - let group = sort_by_s.slice(first as i64, len as usize); - let sorted_idx = group.argsort(SortOptions { - descending: reverse[0], - ..Default::default() - }); - map_sorted_indices_to_group_slice(&sorted_idx, first) - } - }; - - (new_idx[0], new_idx) + } + _ => None, }) .collect(); - - (GroupsProxy::Idx(groups), ordered_by_group_operation) + ca.rename(s.name()); + let s = ca.into_series(); + ac_in.with_series(s, true); + Ok(ac_in) } else { - let mut ac_sort_by = self - .by - .iter() - .map(|e| e.evaluate_on_groups(df, groups, state)) - .collect::>>()?; - let 
sort_by_s = ac_sort_by - .iter() - .map(|s| s.flat_naive().into_owned()) - .collect::>(); - - let ordered_by_group_operation = matches!( - ac_sort_by[0].update_groups, - UpdateGroups::WithSeriesLen | UpdateGroups::WithGroupsLen - ); - let groups = ac_sort_by[0].groups(); - - let groups = groups - .par_iter() - .map(|indicator| { - let new_idx = match indicator { - GroupsIndicator::Idx((_first, idx)) => { - // Safety: - // Group tuples are always in bounds - let groups = sort_by_s - .iter() - .map(|s| unsafe { - s.take_iter_unchecked(&mut idx.iter().map(|i| *i as usize)) - }) - .collect::>(); - - let sorted_idx = - groups[0].argsort_multiple(&groups[1..], &reverse).unwrap(); - map_sorted_indices_to_group_idx(&sorted_idx, idx) - } - GroupsIndicator::Slice([first, len]) => { - let groups = sort_by_s - .iter() - .map(|s| s.slice(first as i64, len as usize)) - .collect::>(); - let sorted_idx = - groups[0].argsort_multiple(&groups[1..], &reverse).unwrap(); - map_sorted_indices_to_group_slice(&sorted_idx, first) - } - }; + let reverse = prepare_reverse(&self.reverse, self.by.len()); - (new_idx[0], new_idx) - }) - .collect(); + let (groups, ordered_by_group_operation) = if self.by.len() == 1 { + let mut ac_sort_by = self.by[0].evaluate_on_groups(df, groups, state)?; + let sort_by_s = ac_sort_by.flat_naive().into_owned(); - (GroupsProxy::Idx(groups), ordered_by_group_operation) - }; + let ordered_by_group_operation = matches!( + ac_sort_by.update_groups, + UpdateGroups::WithSeriesLen | UpdateGroups::WithGroupsLen + ); + let groups = ac_sort_by.groups(); - // if the rhs is already aggregated once, - // it is reordered by the groupby operation - // we must ensure that we are as well. 
- if ordered_by_group_operation { - let s = ac_in.aggregated(); - ac_in.with_series(s.explode().unwrap(), false); - } + let groups = groups + .par_iter() + .map(|indicator| { + let new_idx = match indicator { + GroupsIndicator::Idx((_, idx)) => { + // Safety: + // Group tuples are always in bounds + let group = unsafe { + sort_by_s + .take_iter_unchecked(&mut idx.iter().map(|i| *i as usize)) + }; + + let sorted_idx = group.argsort(SortOptions { + descending: reverse[0], + ..Default::default() + }); + map_sorted_indices_to_group_idx(&sorted_idx, idx) + } + GroupsIndicator::Slice([first, len]) => { + let group = sort_by_s.slice(first as i64, len as usize); + let sorted_idx = group.argsort(SortOptions { + descending: reverse[0], + ..Default::default() + }); + map_sorted_indices_to_group_slice(&sorted_idx, first) + } + }; + + (new_idx[0], new_idx) + }) + .collect(); + + (GroupsProxy::Idx(groups), ordered_by_group_operation) + } else { + let mut ac_sort_by = self + .by + .iter() + .map(|e| e.evaluate_on_groups(df, groups, state)) + .collect::>>()?; + let sort_by_s = ac_sort_by + .iter() + .map(|s| s.flat_naive().into_owned()) + .collect::>(); - ac_in.with_groups(groups); - Ok(ac_in) + let ordered_by_group_operation = matches!( + ac_sort_by[0].update_groups, + UpdateGroups::WithSeriesLen | UpdateGroups::WithGroupsLen + ); + let groups = ac_sort_by[0].groups(); + + let groups = groups + .par_iter() + .map(|indicator| { + let new_idx = match indicator { + GroupsIndicator::Idx((_first, idx)) => { + // Safety: + // Group tuples are always in bounds + let groups = sort_by_s + .iter() + .map(|s| unsafe { + s.take_iter_unchecked(&mut idx.iter().map(|i| *i as usize)) + }) + .collect::>(); + + let sorted_idx = + groups[0].argsort_multiple(&groups[1..], &reverse).unwrap(); + map_sorted_indices_to_group_idx(&sorted_idx, idx) + } + GroupsIndicator::Slice([first, len]) => { + let groups = sort_by_s + .iter() + .map(|s| s.slice(first as i64, len as usize)) + .collect::>(); + let 
sorted_idx = + groups[0].argsort_multiple(&groups[1..], &reverse).unwrap(); + map_sorted_indices_to_group_slice(&sorted_idx, first) + } + }; + + (new_idx[0], new_idx) + }) + .collect(); + + (GroupsProxy::Idx(groups), ordered_by_group_operation) + }; + + // if the rhs is already aggregated once, + // it is reordered by the groupby operation + // we must ensure that we are as well. + if ordered_by_group_operation { + let s = ac_in.aggregated(); + ac_in.with_series(s.explode().unwrap(), false); + } + + ac_in.with_groups(groups); + Ok(ac_in) + } } fn to_field(&self, input_schema: &Schema) -> PolarsResult { diff --git a/py-polars/polars/testing/_private.py b/py-polars/polars/testing/_private.py index eea7ee9b74f30..7fc037dc87375 100644 --- a/py-polars/polars/testing/_private.py +++ b/py-polars/polars/testing/_private.py @@ -27,3 +27,32 @@ def verify_series_and_expr_api( else: assert_series_equal(result_expr, expected) assert_series_equal(result_series, expected) + + +def _to_rust_syntax(df: pli.DataFrame) -> str: + """ + Utility to generate the syntax that creates a polars + 'DataFrame' in Rust. 
+ """ + syntax = "df![\n" + + def format_s(s: pli.Series) -> str: + if s.null_count() == 0: + return str(s.to_list()).replace("'", '"') + else: + tmp = "[" + for val in s: + if val is None: + tmp += "None, " + else: + if isinstance(val, str): + tmp += f'Some("{val}"), ' + else: + tmp += f"Some({val}), " + tmp = tmp[:-2] + "]" + return tmp + + for s in df: + syntax += f' "{s.name}" => {format_s(s)},\n' + syntax += "]" + return syntax diff --git a/py-polars/tests/unit/test_sort.py b/py-polars/tests/unit/test_sort.py index d62710f517485..ed11f39d206dc 100644 --- a/py-polars/tests/unit/test_sort.py +++ b/py-polars/tests/unit/test_sort.py @@ -322,3 +322,22 @@ def test_sorted_join_query_5406() -> None: pl.exclude(["Datetime_right", "Group_right"]) ) assert out["Value_right"].to_list() == [1, None, 2, 1, 2, None] + + +def test_sort_by_in_over_5499() -> None: + df = pl.DataFrame( + { + "group": [1, 1, 1, 2, 2, 2], + "idx": pl.arange(0, 6, eager=True), + "a": [1, 3, 2, 3, 1, 2], + } + ) + assert df.select( + [ + pl.col("idx").sort_by("a").over("group").alias("sorted_1"), + pl.col("idx").shift(1).sort_by("a").over("group").alias("sorted_2"), + ] + ).to_dict(False) == { + "sorted_1": [0, 2, 1, 4, 5, 3], + "sorted_2": [None, 1, 0, 3, 4, None], + }