Commit
fix(rust, python): fix sort_by expression if groups already aggregated (
ritchie46 authored Nov 16, 2022
1 parent 2b27e58 commit 5d096fb
Showing 3 changed files with 180 additions and 95 deletions.
230 changes: 135 additions & 95 deletions polars/polars-lazy/src/physical_plan/expressions/sortby.rs
@@ -97,113 +97,153 @@ impl PhysicalExpr for SortByExpr {
    ) -> PolarsResult<AggregationContext<'a>> {
        let mut ac_in = self.input.evaluate_on_groups(df, groups, state)?;

        // the groups of the lhs of the expressions do not match the series values
        // we must take the slower path.
        if !matches!(ac_in.update_groups, UpdateGroups::No) {
            if self.by.len() > 1 {
                let msg = "This expression is not yet supported for more than two sort columns. \
                Consider opening a feature request.";
                return Err(expression_err!(msg, self.expr, ComputeError));
            }
            let mut ac_sort_by = self.by[0].evaluate_on_groups(df, groups, state)?;
            let sort_by = ac_sort_by.aggregated();
            let mut sort_by = sort_by.list().unwrap().clone();
            let s = ac_in.aggregated();
            let mut s = s.list().unwrap().clone();

            let descending = self.reverse[0];
            let mut ca: ListChunked = s
                .par_iter_indexed()
                .zip(sort_by.par_iter_indexed())
                .map(|(opt_s, s_sort_by)| match (opt_s, s_sort_by) {
                    (Some(s), Some(s_sort_by)) => {
                        if s.len() != s_sort_by.len() {
                            None
                        } else {
                            let idx = s_sort_by.argsort(SortOptions {
                                descending,
                                ..Default::default()
                            });
                            Some(unsafe { s.take_unchecked(&idx).unwrap() })
                        }
                    }
                    _ => None,
                })
                .collect();

            ca.rename(s.name());
            let s = ca.into_series();
            ac_in.with_series(s, true);
            Ok(ac_in)
        } else {
            let reverse = prepare_reverse(&self.reverse, self.by.len());

            let (groups, ordered_by_group_operation) = if self.by.len() == 1 {
                let mut ac_sort_by = self.by[0].evaluate_on_groups(df, groups, state)?;
                let sort_by_s = ac_sort_by.flat_naive().into_owned();

                let ordered_by_group_operation = matches!(
                    ac_sort_by.update_groups,
                    UpdateGroups::WithSeriesLen | UpdateGroups::WithGroupsLen
                );
                let groups = ac_sort_by.groups();

                let groups = groups
                    .par_iter()
                    .map(|indicator| {
                        let new_idx = match indicator {
                            GroupsIndicator::Idx((_, idx)) => {
                                // Safety:
                                // Group tuples are always in bounds
                                let group = unsafe {
                                    sort_by_s
                                        .take_iter_unchecked(&mut idx.iter().map(|i| *i as usize))
                                };

                                let sorted_idx = group.argsort(SortOptions {
                                    descending: reverse[0],
                                    ..Default::default()
                                });
                                map_sorted_indices_to_group_idx(&sorted_idx, idx)
                            }
                            GroupsIndicator::Slice([first, len]) => {
                                let group = sort_by_s.slice(first as i64, len as usize);
                                let sorted_idx = group.argsort(SortOptions {
                                    descending: reverse[0],
                                    ..Default::default()
                                });
                                map_sorted_indices_to_group_slice(&sorted_idx, first)
                            }
                        };

                        (new_idx[0], new_idx)
                    })
                    .collect();

                (GroupsProxy::Idx(groups), ordered_by_group_operation)
            } else {
                let mut ac_sort_by = self
                    .by
                    .iter()
                    .map(|e| e.evaluate_on_groups(df, groups, state))
                    .collect::<PolarsResult<Vec<_>>>()?;
                let sort_by_s = ac_sort_by
                    .iter()
                    .map(|s| s.flat_naive().into_owned())
                    .collect::<Vec<_>>();

                let ordered_by_group_operation = matches!(
                    ac_sort_by[0].update_groups,
                    UpdateGroups::WithSeriesLen | UpdateGroups::WithGroupsLen
                );
                let groups = ac_sort_by[0].groups();

                let groups = groups
                    .par_iter()
                    .map(|indicator| {
                        let new_idx = match indicator {
                            GroupsIndicator::Idx((_first, idx)) => {
                                // Safety:
                                // Group tuples are always in bounds
                                let groups = sort_by_s
                                    .iter()
                                    .map(|s| unsafe {
                                        s.take_iter_unchecked(&mut idx.iter().map(|i| *i as usize))
                                    })
                                    .collect::<Vec<_>>();

                                let sorted_idx =
                                    groups[0].argsort_multiple(&groups[1..], &reverse).unwrap();
                                map_sorted_indices_to_group_idx(&sorted_idx, idx)
                            }
                            GroupsIndicator::Slice([first, len]) => {
                                let groups = sort_by_s
                                    .iter()
                                    .map(|s| s.slice(first as i64, len as usize))
                                    .collect::<Vec<_>>();
                                let sorted_idx =
                                    groups[0].argsort_multiple(&groups[1..], &reverse).unwrap();
                                map_sorted_indices_to_group_slice(&sorted_idx, first)
                            }
                        };

                        (new_idx[0], new_idx)
                    })
                    .collect();

                (GroupsProxy::Idx(groups), ordered_by_group_operation)
            };

            // if the rhs is already aggregated once,
            // it is reordered by the groupby operation
            // we must ensure that we are as well.
            if ordered_by_group_operation {
                let s = ac_in.aggregated();
                ac_in.with_series(s.explode().unwrap(), false);
            }

            ac_in.with_groups(groups);
            Ok(ac_in)
        }
    }

    fn to_field(&self, input_schema: &Schema) -> PolarsResult<Field> {
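To make the intent of the new slow path above concrete, here is a minimal Python sketch (an illustration only, not the actual Rust implementation): when the input is already aggregated per group, each group's values are reordered by the argsort of that group's sort key. The sample lists mirror the idx.shift(1) values and the "a" column from the test added below.

# Conceptual sketch of the slow path, per group:
# sort the already-aggregated values by the argsort of the group's sort key.
values_per_group = [[None, 0, 1], [None, 3, 4]]  # idx.shift(1) within each group
keys_per_group = [[1, 3, 2], [3, 1, 2]]          # "a" within each group

sorted_per_group = []
for vals, keys in zip(values_per_group, keys_per_group):
    order = sorted(range(len(keys)), key=lambda i: keys[i])  # ascending argsort
    sorted_per_group.append([vals[i] for i in order])

print(sorted_per_group)  # [[None, 1, 0], [3, 4, None]]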
26 changes: 26 additions & 0 deletions py-polars/polars/testing/_private.py
@@ -27,3 +27,29 @@ def verify_series_and_expr_api(
    else:
        assert_series_equal(result_expr, expected)
        assert_series_equal(result_series, expected)


def _to_rust_syntax(df: pli.DataFrame) -> str:
    """Utility to generate the syntax that creates a polars 'DataFrame' in Rust."""
    syntax = "df![\n"

    def format_s(s: pli.Series) -> str:
        if s.null_count() == 0:
            return str(s.to_list()).replace("'", '"')
        else:
            tmp = "["
            for val in s:
                if val is None:
                    tmp += "None, "
                else:
                    if isinstance(val, str):
                        tmp += f'Some("{val}"), '
                    else:
                        tmp += f"Some({val}), "
            tmp = tmp[:-2] + "]"
            return tmp

    for s in df:
        syntax += f' "{s.name}" => {format_s(s)},\n'
    syntax += "]"
    return syntax
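For context, a hypothetical way this private helper might be used while debugging (assuming it can be imported from polars.testing._private, the module this diff adds it to; the exact spacing of the output depends on the f-string above):

import polars as pl
from polars.testing._private import _to_rust_syntax  # assumed import path

df = pl.DataFrame({"group": [1, 1, 2], "name": ["a", None, "c"]})
print(_to_rust_syntax(df))
# Prints something like:
# df![
#  "group" => [1, 1, 2],
#  "name" => [Some("a"), None, Some("c")],
# ]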
19 changes: 19 additions & 0 deletions py-polars/tests/unit/test_sort.py
@@ -322,3 +322,22 @@ def test_sorted_join_query_5406() -> None:
        pl.exclude(["Datetime_right", "Group_right"])
    )
    assert out["Value_right"].to_list() == [1, None, 2, 1, 2, None]


def test_sort_by_in_over_5499() -> None:
    df = pl.DataFrame(
        {
            "group": [1, 1, 1, 2, 2, 2],
            "idx": pl.arange(0, 6, eager=True),
            "a": [1, 3, 2, 3, 1, 2],
        }
    )
    assert df.select(
        [
            pl.col("idx").sort_by("a").over("group").alias("sorted_1"),
            pl.col("idx").shift(1).sort_by("a").over("group").alias("sorted_2"),
        ]
    ).to_dict(False) == {
        "sorted_1": [0, 2, 1, 4, 5, 3],
        "sorted_2": [None, 1, 0, 3, 4, None],
    }
