Skip to content

Commit

Permalink
fix: Fix group-by slice on all keys (#18324)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Aug 23, 2024
1 parent 937855a commit 9828c41
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 19 deletions.
18 changes: 15 additions & 3 deletions crates/polars-lazy/src/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1513,13 +1513,25 @@ impl LazyFrame {

/// Apply explode operation. [See eager explode](polars_core::frame::DataFrame::explode).
///
/// Forwards to `explode_impl` with `allow_empty = false`, so a selection that expands to
/// no columns is reported as an error ("no columns provided in explode") when the plan is
/// converted, instead of being silently accepted.
pub fn explode<E: AsRef<[IE]>, IE: Into<Selector> + Clone>(self, columns: E) -> LazyFrame {
self.explode_impl(columns, false)
}

/// Shared implementation behind [`LazyFrame::explode`].
///
/// Converts every entry of `columns` into a `Selector`, appends an explode node to the
/// logical plan and rebuilds the `LazyFrame` with the current optimization state.
///
/// `allow_empty` is forwarded to the plan builder: when `true`, a selection that expands
/// to zero columns is accepted and the explode is dropped from the plan; when `false`,
/// that case is an `InvalidOperation` error.
fn explode_impl<E: AsRef<[IE]>, IE: Into<Selector> + Clone>(
self,
columns: E,
allow_empty: bool,
) -> LazyFrame {
let columns = columns
.as_ref()
.iter()
.map(|e| e.clone().into())
.collect::<Vec<_>>();
let opt_state = self.get_opt_state();
let lp = self.get_plan_builder().explode(columns).build();
let lp = self
.get_plan_builder()
.explode(columns, allow_empty)
.build();
Self::from_logical_plan(lp, opt_state)
}

Expand Down Expand Up @@ -1877,7 +1889,7 @@ impl LazyGroupBy {
.collect::<Vec<_>>();

self.agg([col("*").exclude(&keys).head(n)])
.explode([col("*").exclude(&keys)])
.explode_impl([col("*").exclude(&keys)], true)
}

/// Return last n rows of each group
Expand All @@ -1889,7 +1901,7 @@ impl LazyGroupBy {
.collect::<Vec<_>>();

self.agg([col("*").exclude(&keys).tail(n)])
.explode([col("*").exclude(&keys)])
.explode_impl([col("*").exclude(&keys)], true)
}

/// Apply a function over the groups as a new DataFrame.
Expand Down
7 changes: 5 additions & 2 deletions crates/polars-plan/src/plans/builder_dsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -346,10 +346,13 @@ impl DslBuilder {
.into()
}

/// Wrap the current plan in a `DslPlan::MapFunction` that explodes `columns`.
///
/// The selectors are stored unexpanded on `DslFunction::Explode`; they are resolved
/// against the input schema later, during conversion to IR.
pub fn explode(self, columns: Vec<Selector>) -> Self {
pub fn explode(self, columns: Vec<Selector>, allow_empty: bool) -> Self {
DslPlan::MapFunction {
input: Arc::new(self.0),
function: DslFunction::Explode { columns },
function: DslFunction::Explode {
columns,
allow_empty,
},
}
.into()
}
Expand Down
17 changes: 17 additions & 0 deletions crates/polars-plan/src/plans/conversion/dsl_to_ir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -636,6 +636,23 @@ pub fn to_alp_impl(
let input_schema = lp_arena.get(input).schema(lp_arena);

match function {
// Explode is handled in this match instead of `DslFunction::into_function_ir`
// (which panics for this variant) — presumably because the empty-selection case
// has to return the input node itself rather than a `FunctionIR`.
DslFunction::Explode {
columns,
allow_empty,
} => {
// Resolve selectors to concrete column names and check they exist in the input.
let columns = expand_selectors(columns, &input_schema, &[])?;
validate_columns_in_input(&columns, &input_schema, "explode")?;
// An empty expansion is an error unless the caller opted in via `allow_empty`.
polars_ensure!(!columns.is_empty() || allow_empty, InvalidOperation: "no columns provided in explode");
// Nothing to explode: add no node and hand back the input unchanged.
if columns.is_empty() {
return Ok(input);
}
let function = FunctionIR::Explode {
columns,
schema: Default::default(),
};
let ir = IR::MapFunction { input, function };
return Ok(lp_arena.add(ir));
},
DslFunction::FillNan(fill_value) => {
let exprs = input_schema
.iter()
Expand Down
24 changes: 10 additions & 14 deletions crates/polars-plan/src/plans/functions/dsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ pub enum DslFunction {
OpaquePython(OpaquePythonUdf),
Explode {
columns: Vec<Selector>,
allow_empty: bool,
},
#[cfg(feature = "pivot")]
Unpivot {
Expand Down Expand Up @@ -79,7 +80,7 @@ pub enum StatsFunction {
Max,
}

fn validate_columns<S: AsRef<str>>(
pub(crate) fn validate_columns_in_input<S: AsRef<str>>(
columns: &[S],
input_schema: &Schema,
operation_name: &str,
Expand All @@ -93,20 +94,12 @@ fn validate_columns<S: AsRef<str>>(
impl DslFunction {
pub(crate) fn into_function_ir(self, input_schema: &Schema) -> PolarsResult<FunctionIR> {
let function = match self {
DslFunction::Explode { columns } => {
let columns = expand_selectors(columns, input_schema, &[])?;
validate_columns(columns.as_ref(), input_schema, "explode")?;
FunctionIR::Explode {
columns,
schema: Default::default(),
}
},
#[cfg(feature = "pivot")]
DslFunction::Unpivot { args } => {
let on = expand_selectors(args.on, input_schema, &[])?;
let index = expand_selectors(args.index, input_schema, &[])?;
validate_columns(on.as_ref(), input_schema, "unpivot")?;
validate_columns(index.as_ref(), input_schema, "unpivot")?;
validate_columns_in_input(on.as_ref(), input_schema, "unpivot")?;
validate_columns_in_input(index.as_ref(), input_schema, "unpivot")?;

let args = UnpivotArgsIR {
on: on.iter().map(|s| s.as_ref().into()).collect(),
Expand All @@ -128,7 +121,7 @@ impl DslFunction {
},
DslFunction::Rename { existing, new } => {
let swapping = new.iter().any(|name| input_schema.get(name).is_some());
validate_columns(existing.as_ref(), input_schema, "rename")?;
validate_columns_in_input(existing.as_ref(), input_schema, "rename")?;

FunctionIR::Rename {
existing,
Expand All @@ -139,12 +132,15 @@ impl DslFunction {
},
// Expand the selectors against the input schema and make sure every resulting
// column exists before building the unnest node.
DslFunction::Unnest(selectors) => {
let columns = expand_selectors(selectors, input_schema, &[])?;
validate_columns(columns.as_ref(), input_schema, "explode")?;
// The third argument is `operation_name`: pass this arm's own operation
// ("unnest") rather than the copy-pasted "explode".
validate_columns_in_input(columns.as_ref(), input_schema, "unnest")?;
FunctionIR::Unnest { columns }
},
#[cfg(feature = "python")]
DslFunction::OpaquePython(inner) => FunctionIR::OpaquePython(inner),
DslFunction::Stats(_) | DslFunction::FillNan(_) | DslFunction::Drop(_) => {
DslFunction::Stats(_)
| DslFunction::FillNan(_)
| DslFunction::Drop(_)
| DslFunction::Explode { .. } => {
// We should not reach this.
panic!("impl error")
},
Expand Down
20 changes: 20 additions & 0 deletions py-polars/tests/unit/operations/test_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,3 +253,23 @@ def test_slice_pushdown_simple_projection_18288() -> None:
"col": [0],
"literal": [None],
}


def test_group_by_slice_all_keys() -> None:
    """Regression test for #18324: group-by slicing when every column is a key.

    All rows in the frame are distinct, so each group holds exactly one row and
    `tail(1)` must produce the same frame as `head(1)`.
    """
    names = ["Tom", "Nick", "Marry", "Krish", "Jack", None]
    dates = [
        "2020-01-01",
        "2020-01-02",
        "2020-01-03",
        "2020-01-04",
        "2020-01-05",
        None,
    ]
    counts = [5, 6, 6, 7, 8, 5]
    frame = pl.DataFrame({"a": names, "b": dates, "c": counts})

    # Every column is a grouping key, so no non-key columns remain to slice.
    grouped = frame.group_by(["a", "b", "c"], maintain_order=True)
    last_rows = grouped.tail(1)
    first_rows = grouped.head(1)
    assert_frame_equal(last_rows, first_rows)

0 comments on commit 9828c41

Please sign in to comment.