From e2f80f10e0449815f943a55dc9c1feffc433c8d6 Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Mon, 24 Oct 2022 15:23:48 +0200 Subject: [PATCH] =?UTF-8?q?fix(rust,=20python):=20fix=20explicit=20list=20?= =?UTF-8?q?+=20sort=20aggregation=20in=20groupby=20co=E2=80=A6=20(#5317)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- polars/Cargo.toml | 3 - polars/polars-lazy/Cargo.toml | 4 +- polars/polars-lazy/polars-plan/Cargo.toml | 1 - .../polars-plan/src/dsl/function_expr/list.rs | 1 + .../polars-plan/src/dsl/function_expr/mod.rs | 5 -- .../src/dsl/function_expr/schema.rs | 3 - .../polars-plan/src/dsl/functions.rs | 3 - .../polars-lazy/polars-plan/src/dsl/list.rs | 5 +- polars/polars-lazy/polars-plan/src/dsl/mod.rs | 3 - polars/polars-lazy/src/dsl/list.rs | 5 ++ polars/polars-lazy/src/dsl/mod.rs | 2 - .../src/physical_plan/expressions/sort.rs | 66 +++++++++++-------- polars/polars-ops/Cargo.toml | 3 +- .../polars-ops/src/chunked_array/list/mod.rs | 3 - .../src/chunked_array/list/namespace.rs | 5 +- py-polars/Cargo.toml | 1 - py-polars/src/lazy/dsl.rs | 5 +- py-polars/tests/unit/test_sort.py | 10 +++ 18 files changed, 65 insertions(+), 63 deletions(-) diff --git a/polars/Cargo.toml b/polars/Cargo.toml index 0a477cd6d2a33..e53b20e58cd31 100644 --- a/polars/Cargo.toml +++ b/polars/Cargo.toml @@ -90,7 +90,6 @@ lazy_regex = ["polars-lazy/regex"] cum_agg = ["polars-core/cum_agg", "polars-core/cum_agg"] rolling_window = ["polars-core/rolling_window", "polars-lazy/rolling_window", "polars-time/rolling_window"] interpolate = ["polars-core/interpolate", "polars-lazy/interpolate"] -list = ["polars-lazy/list", "polars-ops/list"] rank = ["polars-core/rank", "polars-lazy/rank"] diff = ["polars-core/diff", "polars-lazy/diff", "polars-ops/diff"] pct_change = ["polars-core/pct_change", "polars-lazy/pct_change"] @@ -136,7 +135,6 @@ test = [ "private", "rolling_window", "rank", - "list", "round_series", "csv-file", "dtype-categorical", @@ -254,7 +252,6 @@ docs-selection = [ "interpolate", "diff", "rank", - "list", "arange", "diagonal_concat", "horizontal_concat", diff --git a/polars/polars-lazy/Cargo.toml b/polars/polars-lazy/Cargo.toml index 2dc0e311c726a..490d5aeb42231 100644 --- a/polars/polars-lazy/Cargo.toml +++ b/polars/polars-lazy/Cargo.toml @@ -57,7 +57,7 @@ timezones = ["polars-plan/timezones"] true_div = ["polars-plan/true_div"] # operations -is_in = ["polars-plan/is_in", "list"] +is_in = ["polars-plan/is_in"] repeat_by = ["polars-plan/repeat_by"] round_series = ["polars-plan/round_series"] is_first = ["polars-plan/is_first"] @@ -77,7 +77,6 @@ rank = ["polars-plan/rank"] diff = ["polars-plan/diff", "polars-plan/diff"] pct_change = ["polars-plan/pct_change"] moment = ["polars-plan/moment"] -list = ["polars-plan/list"] abs = ["polars-plan/abs"] random = ["polars-plan/random"] dynamic_groupby = ["polars-plan/dynamic_groupby", "polars-time", "temporal"] @@ -116,7 +115,6 @@ test = [ "private", "rolling_window", "rank", - "list", "round_series", "csv-file", "dtype-categorical", diff --git a/polars/polars-lazy/polars-plan/Cargo.toml b/polars/polars-lazy/polars-plan/Cargo.toml index 186f74366e344..72b75688b21cb 100644 --- a/polars/polars-lazy/polars-plan/Cargo.toml +++ b/polars/polars-lazy/polars-plan/Cargo.toml @@ -75,7 +75,6 @@ rank = ["polars-core/rank"] diff = ["polars-core/diff", "polars-ops/diff"] pct_change = ["polars-core/pct_change"] moment = ["polars-core/moment"] -list = ["polars-ops/list"] abs = ["polars-core/abs"] random = ["polars-core/random"] dynamic_groupby = ["polars-core/dynamic_groupby"] diff --git a/polars/polars-lazy/polars-plan/src/dsl/function_expr/list.rs b/polars/polars-lazy/polars-plan/src/dsl/function_expr/list.rs index 072189b90874e..5d81b9cd83117 100644 --- a/polars/polars-lazy/polars-plan/src/dsl/function_expr/list.rs +++ b/polars/polars-lazy/polars-plan/src/dsl/function_expr/list.rs @@ -17,6 +17,7 @@ impl Display for ListFunction { let name = match self { Concat => "concat", + #[cfg(feature = "is_in")] Contains => "contains", Slice => "slice", }; diff --git a/polars/polars-lazy/polars-plan/src/dsl/function_expr/mod.rs b/polars/polars-lazy/polars-plan/src/dsl/function_expr/mod.rs index e69917da549d6..fe5797876cd20 100644 --- a/polars/polars-lazy/polars-plan/src/dsl/function_expr/mod.rs +++ b/polars/polars-lazy/polars-plan/src/dsl/function_expr/mod.rs @@ -8,7 +8,6 @@ mod dispatch; mod fill_null; #[cfg(feature = "is_in")] mod is_in; -#[cfg(any(feature = "is_in", feature = "list"))] mod list; mod nan; mod pow; @@ -33,7 +32,6 @@ mod trigonometry; use std::fmt::{Display, Formatter}; -#[cfg(feature = "list")] pub(super) use list::ListFunction; use polars_core::prelude::*; #[cfg(feature = "serde")] @@ -91,7 +89,6 @@ pub enum FunctionExpr { min: Option>, max: Option>, }, - #[cfg(feature = "list")] ListExpr(ListFunction), #[cfg(feature = "dtype-struct")] StructExpr(StructFunction), @@ -147,7 +144,6 @@ impl Display for FunctionExpr { (Some(_), None) => "clip_min", _ => unreachable!(), }, - #[cfg(feature = "list")] ListExpr(func) => return write!(f, "{}", func), #[cfg(feature = "dtype-struct")] StructExpr(func) => return write!(f, "{}", func), @@ -301,7 +297,6 @@ impl From for SpecialEq> { Clip { min, max } => { map_owned!(clip::clip, min.clone(), max.clone()) } - #[cfg(feature = "list")] ListExpr(lf) => { use ListFunction::*; match lf { diff --git a/polars/polars-lazy/polars-plan/src/dsl/function_expr/schema.rs b/polars/polars-lazy/polars-plan/src/dsl/function_expr/schema.rs index 8b257efae696c..569a046b9baf3 100644 --- a/polars/polars-lazy/polars-plan/src/dsl/function_expr/schema.rs +++ b/polars/polars-lazy/polars-plan/src/dsl/function_expr/schema.rs @@ -25,7 +25,6 @@ impl FunctionExpr { }; // map all dtypes - #[cfg(feature = "list")] let map_dtypes = |func: &dyn Fn(&[&DataType]) -> DataType| { let mut fld = fields[0].clone(); let dtypes = fields.iter().map(|fld| fld.data_type()).collect::>(); @@ -58,7 +57,6 @@ impl FunctionExpr { }; // inner super type of lists - #[cfg(feature = "list")] let inner_super_type_list = || { map_dtypes(&|dts| { let mut super_type_inner = None; @@ -157,7 +155,6 @@ impl FunctionExpr { Nan(n) => n.get_field(fields), #[cfg(feature = "round_series")] Clip { .. } => same_type(), - #[cfg(feature = "list")] ListExpr(l) => { use ListFunction::*; match l { diff --git a/polars/polars-lazy/polars-plan/src/dsl/functions.rs b/polars/polars-lazy/polars-plan/src/dsl/functions.rs index c52181a8aef2d..0cf8676d13cbf 100644 --- a/polars/polars-lazy/polars-plan/src/dsl/functions.rs +++ b/polars/polars-lazy/polars-plan/src/dsl/functions.rs @@ -13,7 +13,6 @@ use polars_core::utils::get_supertype; #[cfg(feature = "arg_where")] use crate::dsl::function_expr::FunctionExpr; -#[cfg(feature = "list")] use crate::dsl::function_expr::ListFunction; #[cfg(feature = "strings")] use crate::dsl::function_expr::StringFunction; @@ -300,8 +299,6 @@ pub fn format_str>(format: &str, args: E) -> PolarsResult } /// Concat lists entries. -#[cfg(feature = "list")] -#[cfg_attr(docsrs, doc(cfg(feature = "list")))] pub fn concat_lst, IE: Into + Clone>(s: E) -> Expr { let s = s.as_ref().iter().map(|e| e.clone().into()).collect(); diff --git a/polars/polars-lazy/polars-plan/src/dsl/list.rs b/polars/polars-lazy/polars-plan/src/dsl/list.rs index e4c78283ba846..db9e21c4861bb 100644 --- a/polars/polars-lazy/polars-plan/src/dsl/list.rs +++ b/polars/polars-lazy/polars-plan/src/dsl/list.rs @@ -1,4 +1,5 @@ use polars_core::prelude::*; +#[cfg(feature = "diff")] use polars_core::series::ops::NullBehavior; use polars_ops::prelude::*; @@ -83,10 +84,10 @@ impl ListNameSpace { } /// Sort every sublist. - pub fn sort(self, reverse: bool) -> Expr { + pub fn sort(self, options: SortOptions) -> Expr { self.0 .map( - move |s| Ok(s.list()?.lst_sort(reverse).into_series()), + move |s| Ok(s.list()?.lst_sort(options).into_series()), GetOutput::same_type(), ) .with_fmt("arr.sort") diff --git a/polars/polars-lazy/polars-plan/src/dsl/mod.rs b/polars/polars-lazy/polars-plan/src/dsl/mod.rs index 372756f43571b..6e0bdc7cd06ae 100644 --- a/polars/polars-lazy/polars-plan/src/dsl/mod.rs +++ b/polars/polars-lazy/polars-plan/src/dsl/mod.rs @@ -10,7 +10,6 @@ mod from; pub(crate) mod function_expr; #[cfg(feature = "compile")] pub mod functions; -#[cfg(feature = "list")] mod list; #[cfg(feature = "meta")] mod meta; @@ -28,7 +27,6 @@ use std::sync::Arc; pub use expr::*; pub use function_expr::*; pub use functions::*; -#[cfg(feature = "list")] pub use list::*; pub use options::*; use polars_arrow::prelude::QuantileInterpolOptions; @@ -2253,7 +2251,6 @@ impl Expr { pub fn dt(self) -> dt::DateLikeNameSpace { dt::DateLikeNameSpace(self) } - #[cfg(feature = "list")] pub fn arr(self) -> list::ListNameSpace { list::ListNameSpace(self) } diff --git a/polars/polars-lazy/src/dsl/list.rs b/polars/polars-lazy/src/dsl/list.rs index d8c906cdd041e..50b954b9bb5b8 100644 --- a/polars/polars-lazy/src/dsl/list.rs +++ b/polars/polars-lazy/src/dsl/list.rs @@ -1,8 +1,13 @@ +#[cfg(feature = "list_eval")] use std::sync::Mutex; +#[cfg(feature = "list_eval")] use polars_arrow::utils::CustomIterTools; +#[cfg(feature = "list_eval")] use polars_core::prelude::*; +#[cfg(feature = "list_eval")] use polars_plan::dsl::*; +#[cfg(feature = "list_eval")] use rayon::prelude::*; use crate::prelude::*; diff --git a/polars/polars-lazy/src/dsl/mod.rs b/polars/polars-lazy/src/dsl/mod.rs index 215b6ba287eb4..12083524fdeac 100644 --- a/polars/polars-lazy/src/dsl/mod.rs +++ b/polars/polars-lazy/src/dsl/mod.rs @@ -3,7 +3,6 @@ mod eval; pub mod functions; mod into; -#[cfg(feature = "list")] mod list; #[cfg(feature = "cumulative_eval")] @@ -11,7 +10,6 @@ pub use eval::*; pub use functions::*; #[cfg(feature = "cumulative_eval")] use into::IntoExpr; -#[cfg(feature = "list")] pub use list::*; pub use polars_plan::dsl::*; pub use polars_plan::logical_plan::UdfSchema; diff --git a/polars/polars-lazy/src/physical_plan/expressions/sort.rs b/polars/polars-lazy/src/physical_plan/expressions/sort.rs index cf8e2692cda41..ece38ef7895f1 100644 --- a/polars/polars-lazy/src/physical_plan/expressions/sort.rs +++ b/polars/polars-lazy/src/physical_plan/expressions/sort.rs @@ -65,38 +65,46 @@ impl PhysicalExpr for SortExpr { state: &ExecutionState, ) -> PolarsResult> { let mut ac = self.physical_expr.evaluate_on_groups(df, groups, state)?; - let series = ac.flat_naive().into_owned(); + match ac.agg_state() { + AggState::AggregatedList(s) => { + let ca = s.list().unwrap(); + let out = ca.lst_sort(self.options); + ac.with_series(out.into_series(), true); + } + _ => { + let series = ac.flat_naive().into_owned(); - let groups = match ac.groups().as_ref() { - GroupsProxy::Idx(groups) => { - groups - .iter() - .map(|(first, idx)| { - // Safety: - // Group tuples are always in bounds - let group = unsafe { - series.take_iter_unchecked(&mut idx.iter().map(|i| *i as usize)) - }; + let groups = match ac.groups().as_ref() { + GroupsProxy::Idx(groups) => { + groups + .iter() + .map(|(first, idx)| { + // Safety: + // Group tuples are always in bounds + let group = unsafe { + series.take_iter_unchecked(&mut idx.iter().map(|i| *i as usize)) + }; - let sorted_idx = group.argsort(self.options); - let new_idx = map_sorted_indices_to_group_idx(&sorted_idx, idx); - (new_idx.first().copied().unwrap_or(first), new_idx) - }) - .collect() + let sorted_idx = group.argsort(self.options); + let new_idx = map_sorted_indices_to_group_idx(&sorted_idx, idx); + (new_idx.first().copied().unwrap_or(first), new_idx) + }) + .collect() + } + GroupsProxy::Slice { groups, .. } => groups + .iter() + .map(|&[first, len]| { + let group = series.slice(first as i64, len as usize); + let sorted_idx = group.argsort(self.options); + let new_idx = map_sorted_indices_to_group_slice(&sorted_idx, first); + (new_idx.first().copied().unwrap_or(first), new_idx) + }) + .collect(), + }; + let groups = GroupsProxy::Idx(groups); + ac.with_groups(groups); } - GroupsProxy::Slice { groups, .. } => groups - .iter() - .map(|&[first, len]| { - let group = series.slice(first as i64, len as usize); - let sorted_idx = group.argsort(self.options); - let new_idx = map_sorted_indices_to_group_slice(&sorted_idx, first); - (new_idx.first().copied().unwrap_or(first), new_idx) - }) - .collect(), - }; - let groups = GroupsProxy::Idx(groups); - - ac.with_groups(groups); + } Ok(ac) } diff --git a/polars/polars-ops/Cargo.toml b/polars/polars-ops/Cargo.toml index 4aafe8e2b2c23..d1c44384d64d2 100644 --- a/polars/polars-ops/Cargo.toml +++ b/polars/polars-ops/Cargo.toml @@ -32,8 +32,7 @@ propagate_nans = [] # ops to_dummies = [] -list_to_struct = ["polars-core/dtype-struct", "list"] -list = [] +list_to_struct = ["polars-core/dtype-struct"] diff = ["polars-core/diff"] strings = ["polars-core/strings"] string_justify = ["polars-core/strings"] diff --git a/polars/polars-ops/src/chunked_array/list/mod.rs b/polars/polars-ops/src/chunked_array/list/mod.rs index c839e5099bf07..74b35f7b7abe1 100644 --- a/polars/polars-ops/src/chunked_array/list/mod.rs +++ b/polars/polars-ops/src/chunked_array/list/mod.rs @@ -2,13 +2,10 @@ use polars_core::prelude::*; #[cfg(feature = "hash")] pub(crate) mod hash; -#[cfg(feature = "list")] -#[cfg_attr(docsrs, doc(cfg(feature = "list")))] mod namespace; #[cfg(feature = "list_to_struct")] mod to_struct; -#[cfg(feature = "list")] pub use namespace::*; #[cfg(feature = "list_to_struct")] pub use to_struct::*; diff --git a/polars/polars-ops/src/chunked_array/list/namespace.rs b/polars/polars-ops/src/chunked_array/list/namespace.rs index ec7949912bfbf..c7c41cc97a29f 100644 --- a/polars/polars-ops/src/chunked_array/list/namespace.rs +++ b/polars/polars-ops/src/chunked_array/list/namespace.rs @@ -4,6 +4,7 @@ use std::fmt::Write; use polars_arrow::kernels::list::sublist_get; use polars_arrow::prelude::ValueSize; use polars_core::chunked_array::builder::get_list_builder; +#[cfg(feature = "diff")] use polars_core::series::ops::NullBehavior; use polars_core::utils::{try_get_supertype, CustomIterTools}; @@ -133,9 +134,9 @@ pub trait ListNameSpaceImpl: AsList { } #[must_use] - fn lst_sort(&self, reverse: bool) -> ListChunked { + fn lst_sort(&self, options: SortOptions) -> ListChunked { let ca = self.as_list(); - ca.apply_amortized(|s| s.as_ref().sort(reverse)) + ca.apply_amortized(|s| s.as_ref().sort_with(options)) } #[must_use] diff --git a/py-polars/Cargo.toml b/py-polars/Cargo.toml index 25b9b5d674577..65b77d88461de 100644 --- a/py-polars/Cargo.toml +++ b/py-polars/Cargo.toml @@ -128,7 +128,6 @@ features = [ "cum_agg", "rolling_window", "interpolate", - "list", "rank", "diff", "moment", diff --git a/py-polars/src/lazy/dsl.rs b/py-polars/src/lazy/dsl.rs index f52443460bc01..98a675e89d302 100644 --- a/py-polars/src/lazy/dsl.rs +++ b/py-polars/src/lazy/dsl.rs @@ -1302,7 +1302,10 @@ impl PyExpr { self.inner .clone() .arr() - .sort(reverse) + .sort(SortOptions { + descending: reverse, + ..Default::default() + }) .with_fmt("arr.sort") .into() } diff --git a/py-polars/tests/unit/test_sort.py b/py-polars/tests/unit/test_sort.py index 28afd81f733eb..01f4ab1c38f02 100644 --- a/py-polars/tests/unit/test_sort.py +++ b/py-polars/tests/unit/test_sort.py @@ -278,3 +278,13 @@ def test_sort_slice_fast_path_5245() -> None: assert df.sort("foo").limit(1).select("foo").collect().to_dict(False) == { "foo": ["a"] } + + +def test_explicit_list_agg_sort_in_groupby() -> None: + df = pl.DataFrame({"A": ["a", "a", "a", "b", "b", "a"], "B": [1, 2, 3, 4, 5, 6]}) + assert ( + df.groupby("A") + .agg(pl.col("B").list().sort(reverse=True)) + .sort("A") + .frame_equal(df.groupby("A").agg(pl.col("B").sort(reverse=True)).sort("A")) + )