diff --git a/crates/polars-ops/src/frame/pivot/mod.rs b/crates/polars-ops/src/frame/pivot/mod.rs index 68c34e8ac99b..fff3e297d504 100644 --- a/crates/polars-ops/src/frame/pivot/mod.rs +++ b/crates/polars-ops/src/frame/pivot/mod.rs @@ -77,11 +77,10 @@ fn restore_logical_type(s: &Series, logical_type: &DataType) -> Series { } } -/// Determine `values` columns. +/// Determine `values` columns, which is optional in `pivot` calls. /// -/// When the optional `values` parameter is `None`, we use all remaining columns in the `DataFrame` -/// after `index` and `columns` have been excluded. When `values` is `Some`, we return a vector of -/// strings. +/// If not specified (i.e. is `None`, we use all remaining columns in the `DataFrame`)after `index` +/// and `columns` have been excluded. fn _get_values_columns( df: &DataFrame, index: &[String], @@ -98,20 +97,19 @@ where .map(|s| s.as_ref().to_string()) .collect::>(), None => { - let column_names = df.get_column_names_owned(); - let mut column_set = PlHashSet::::with_capacity(column_names.len()); + let mut column_set = PlHashSet::::with_capacity(index.len() + columns.len()); - // Column names are always unique. - column_names.into_iter().for_each(|s| { - column_set.insert_unique_unchecked(s.to_string()); - }); - - // Remove `index` and `column` columns. + // Hash columns we don't want to include index.iter().chain(columns.iter()).for_each(|s| { - column_set.remove(s); + column_set.insert_unique_unchecked(s.to_owned()); }); - column_set.drain().collect() + // filter out + df.get_column_names_owned() + .into_iter() + .map(|s| s.to_string()) + .filter(|s| !column_set.contains(s)) + .collect() }, } } diff --git a/py-polars/tests/unit/operations/test_pivot.py b/py-polars/tests/unit/operations/test_pivot.py index 2cba1ee326f9..e9f5617a6066 100644 --- a/py-polars/tests/unit/operations/test_pivot.py +++ b/py-polars/tests/unit/operations/test_pivot.py @@ -61,9 +61,7 @@ def test_pivot_no_values() -> None: } ) - # the order of the output columns is volatile - assert set(result.columns) == set(expected.columns) - assert_frame_equal(result, expected.select(result.columns)) + assert_frame_equal(result, expected) def test_pivot_list() -> None: