Sort by struct columns.

Struct sort keys are flattened so that every struct field becomes its own
row-encoded column, and any sort whose keys include a struct is routed through
the row-format arg-sort. A single non-struct key keeps using the plain
`arg_sort` path.

diff --git a/polars/polars-core/src/chunked_array/ops/sort/arg_sort_multiple.rs b/polars/polars-core/src/chunked_array/ops/sort/arg_sort_multiple.rs
index 3d706383b12c..33dc5c164910 100644
--- a/polars/polars-core/src/chunked_array/ops/sort/arg_sort_multiple.rs
+++ b/polars/polars-core/src/chunked_array/ops/sort/arg_sort_multiple.rs
@@ -91,11 +91,24 @@ pub fn _get_rows_encoded(
 
     for (by, descending) in by.iter().zip(descending) {
         let arr = _get_rows_encoded_compat_array(by)?;
-        cols.push(arr);
-        fields.push(SortField {
+        let sort_field = SortField {
             descending: *descending,
             nulls_last,
-        })
+        };
+        match arr.data_type() {
+            // flatten the struct fields
+            ArrowDataType::Struct(_) => {
+                let arr = arr.as_any().downcast_ref::<StructArray>().unwrap();
+                for arr in arr.values() {
+                    cols.push(arr.clone() as ArrayRef);
+                    fields.push(sort_field.clone())
+                }
+            }
+            _ => {
+                cols.push(arr);
+                fields.push(sort_field)
+            }
+        }
     }
     Ok(convert_columns(&cols, &fields))
 }
diff --git a/polars/polars-core/src/chunked_array/ops/sort/mod.rs b/polars/polars-core/src/chunked_array/ops/sort/mod.rs
index f55e370e0472..b38f8b307e34 100644
--- a/polars/polars-core/src/chunked_array/ops/sort/mod.rs
+++ b/polars/polars-core/src/chunked_array/ops/sort/mod.rs
@@ -714,6 +714,16 @@ pub(crate) fn convert_sort_column_multi_sort(
                 s.cast(&UInt8).unwrap()
             }
         }
+        #[cfg(feature = "dtype-struct")]
+        Struct(_) => {
+            let ca = s.struct_().unwrap();
+            let new_fields = ca
+                .fields()
+                .iter()
+                .map(|s| convert_sort_column_multi_sort(s, row_ordering))
+                .collect::<PolarsResult<Vec<_>>>()?;
+            return StructChunked::new(ca.name(), &new_fields).map(|ca| ca.into_series());
+        }
         _ => {
             let phys = s.to_physical_repr().into_owned();
             polars_ensure!(
diff --git a/polars/polars-core/src/frame/mod.rs b/polars/polars-core/src/frame/mod.rs
index 96769ea7c0c4..1fd7b927d242 100644
--- a/polars/polars-core/src/frame/mod.rs
+++ b/polars/polars-core/src/frame/mod.rs
@@ -1802,6 +1802,15 @@ impl DataFrame {
             return self.top_k_impl(k, descending, by_column, nulls_last);
         }
 
+        #[cfg(feature = "dtype-struct")]
+        let has_struct = by_column
+            .iter()
+            .any(|s| matches!(s.dtype(), DataType::Struct(_)));
+
+        #[cfg(not(feature = "dtype-struct"))]
+        #[allow(non_upper_case_globals)]
+        const has_struct: bool = false;
+
         // a lot of indirection in both sorting and take
         let mut df = self.clone();
         let df = df.as_single_chunk_par();
@@ -1812,8 +1821,8 @@ impl DataFrame {
         // as expressions are not present (they are renamed to _POLARS_SORT_COLUMN_i.
         let first_descending = descending[0];
         let first_by_column = by_column[0].name().to_string();
-        let mut take = match by_column.len() {
-            1 => {
+        let mut take = match (by_column.len(), has_struct) {
+            (1, false) => {
                 let s = &by_column[0];
                 let options = SortOptions {
                     descending: descending[0],
@@ -1834,7 +1843,7 @@ impl DataFrame {
                 s.arg_sort(options)
             }
             _ => {
-                if nulls_last || std::env::var("POLARS_ROW_FMT_SORT").is_ok() {
+                if nulls_last || has_struct || std::env::var("POLARS_ROW_FMT_SORT").is_ok() {
                     argsort_multiple_row_fmt(&by_column, descending, nulls_last, parallel)?
                 } else {
                     let (first, by_column, descending) = prepare_arg_sort(by_column, descending)?;
diff --git a/py-polars/tests/unit/operations/test_sort.py b/py-polars/tests/unit/operations/test_sort.py
index 9480913a37dd..3a07212d5695 100644
--- a/py-polars/tests/unit/operations/test_sort.py
+++ b/py-polars/tests/unit/operations/test_sort.py
@@ -554,3 +554,11 @@ def test_limit_larger_than_sort() -> None:
     assert pl.LazyFrame({"a": [1]}).sort("a").limit(30).collect().to_dict(False) == {
         "a": [1]
     }
+
+
+def test_sort_by_struct() -> None:
+    df = pl.Series([{"a": 300}, {"a": 20}, {"a": 55}]).to_frame("st").with_row_count()
+    assert df.sort("st").to_dict(False) == {
+        "row_nr": [1, 2, 0],
+        "st": [{"a": 20}, {"a": 55}, {"a": 300}],
+    }