Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(rust, python): fix sort_by expression if groups already aggregated #5518

Merged
merged 1 commit into from
Nov 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
230 changes: 135 additions & 95 deletions polars/polars-lazy/src/physical_plan/expressions/sortby.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,113 +97,153 @@ impl PhysicalExpr for SortByExpr {
) -> PolarsResult<AggregationContext<'a>> {
let mut ac_in = self.input.evaluate_on_groups(df, groups, state)?;

let reverse = prepare_reverse(&self.reverse, self.by.len());

let (groups, ordered_by_group_operation) = if self.by.len() == 1 {
// the groups of the lhs of the expressions do not match the series values
// we must take the slower path.
if !matches!(ac_in.update_groups, UpdateGroups::No) {
if self.by.len() > 1 {
let msg = "This expression is not yet supported for more than two sort columns. \
Consider opeingin a feature request.";
return Err(expression_err!(msg, self.expr, ComputeError));
}
let mut ac_sort_by = self.by[0].evaluate_on_groups(df, groups, state)?;
let sort_by_s = ac_sort_by.flat_naive().into_owned();
let sort_by = ac_sort_by.aggregated();
let mut sort_by = sort_by.list().unwrap().clone();
let s = ac_in.aggregated();
let mut s = s.list().unwrap().clone();

let ordered_by_group_operation = matches!(
ac_sort_by.update_groups,
UpdateGroups::WithSeriesLen | UpdateGroups::WithGroupsLen
);
let groups = ac_sort_by.groups();

let groups = groups
.par_iter()
.map(|indicator| {
let new_idx = match indicator {
GroupsIndicator::Idx((_, idx)) => {
// Safety:
// Group tuples are always in bounds
let group = unsafe {
sort_by_s.take_iter_unchecked(&mut idx.iter().map(|i| *i as usize))
};

let sorted_idx = group.argsort(SortOptions {
descending: reverse[0],
let descending = self.reverse[0];
let mut ca: ListChunked = s
.par_iter_indexed()
.zip(sort_by.par_iter_indexed())
.map(|(opt_s, s_sort_by)| match (opt_s, s_sort_by) {
(Some(s), Some(s_sort_by)) => {
if s.len() != s_sort_by.len() {
None
} else {
let idx = s_sort_by.argsort(SortOptions {
descending,
..Default::default()
});
map_sorted_indices_to_group_idx(&sorted_idx, idx)
Some(unsafe { s.take_unchecked(&idx).unwrap() })
}
GroupsIndicator::Slice([first, len]) => {
let group = sort_by_s.slice(first as i64, len as usize);
let sorted_idx = group.argsort(SortOptions {
descending: reverse[0],
..Default::default()
});
map_sorted_indices_to_group_slice(&sorted_idx, first)
}
};

(new_idx[0], new_idx)
}
_ => None,
})
.collect();

(GroupsProxy::Idx(groups), ordered_by_group_operation)
ca.rename(s.name());
let s = ca.into_series();
ac_in.with_series(s, true);
Ok(ac_in)
} else {
let mut ac_sort_by = self
.by
.iter()
.map(|e| e.evaluate_on_groups(df, groups, state))
.collect::<PolarsResult<Vec<_>>>()?;
let sort_by_s = ac_sort_by
.iter()
.map(|s| s.flat_naive().into_owned())
.collect::<Vec<_>>();

let ordered_by_group_operation = matches!(
ac_sort_by[0].update_groups,
UpdateGroups::WithSeriesLen | UpdateGroups::WithGroupsLen
);
let groups = ac_sort_by[0].groups();

let groups = groups
.par_iter()
.map(|indicator| {
let new_idx = match indicator {
GroupsIndicator::Idx((_first, idx)) => {
// Safety:
// Group tuples are always in bounds
let groups = sort_by_s
.iter()
.map(|s| unsafe {
s.take_iter_unchecked(&mut idx.iter().map(|i| *i as usize))
})
.collect::<Vec<_>>();

let sorted_idx =
groups[0].argsort_multiple(&groups[1..], &reverse).unwrap();
map_sorted_indices_to_group_idx(&sorted_idx, idx)
}
GroupsIndicator::Slice([first, len]) => {
let groups = sort_by_s
.iter()
.map(|s| s.slice(first as i64, len as usize))
.collect::<Vec<_>>();
let sorted_idx =
groups[0].argsort_multiple(&groups[1..], &reverse).unwrap();
map_sorted_indices_to_group_slice(&sorted_idx, first)
}
};
let reverse = prepare_reverse(&self.reverse, self.by.len());

(new_idx[0], new_idx)
})
.collect();
let (groups, ordered_by_group_operation) = if self.by.len() == 1 {
let mut ac_sort_by = self.by[0].evaluate_on_groups(df, groups, state)?;
let sort_by_s = ac_sort_by.flat_naive().into_owned();

(GroupsProxy::Idx(groups), ordered_by_group_operation)
};
let ordered_by_group_operation = matches!(
ac_sort_by.update_groups,
UpdateGroups::WithSeriesLen | UpdateGroups::WithGroupsLen
);
let groups = ac_sort_by.groups();

// if the rhs is already aggregated once,
// it is reordered by the groupby operation
// we must ensure that we are as well.
if ordered_by_group_operation {
let s = ac_in.aggregated();
ac_in.with_series(s.explode().unwrap(), false);
}
let groups = groups
.par_iter()
.map(|indicator| {
let new_idx = match indicator {
GroupsIndicator::Idx((_, idx)) => {
// Safety:
// Group tuples are always in bounds
let group = unsafe {
sort_by_s
.take_iter_unchecked(&mut idx.iter().map(|i| *i as usize))
};

let sorted_idx = group.argsort(SortOptions {
descending: reverse[0],
..Default::default()
});
map_sorted_indices_to_group_idx(&sorted_idx, idx)
}
GroupsIndicator::Slice([first, len]) => {
let group = sort_by_s.slice(first as i64, len as usize);
let sorted_idx = group.argsort(SortOptions {
descending: reverse[0],
..Default::default()
});
map_sorted_indices_to_group_slice(&sorted_idx, first)
}
};

(new_idx[0], new_idx)
})
.collect();

(GroupsProxy::Idx(groups), ordered_by_group_operation)
} else {
let mut ac_sort_by = self
.by
.iter()
.map(|e| e.evaluate_on_groups(df, groups, state))
.collect::<PolarsResult<Vec<_>>>()?;
let sort_by_s = ac_sort_by
.iter()
.map(|s| s.flat_naive().into_owned())
.collect::<Vec<_>>();

ac_in.with_groups(groups);
Ok(ac_in)
let ordered_by_group_operation = matches!(
ac_sort_by[0].update_groups,
UpdateGroups::WithSeriesLen | UpdateGroups::WithGroupsLen
);
let groups = ac_sort_by[0].groups();

let groups = groups
.par_iter()
.map(|indicator| {
let new_idx = match indicator {
GroupsIndicator::Idx((_first, idx)) => {
// Safety:
// Group tuples are always in bounds
let groups = sort_by_s
.iter()
.map(|s| unsafe {
s.take_iter_unchecked(&mut idx.iter().map(|i| *i as usize))
})
.collect::<Vec<_>>();

let sorted_idx =
groups[0].argsort_multiple(&groups[1..], &reverse).unwrap();
map_sorted_indices_to_group_idx(&sorted_idx, idx)
}
GroupsIndicator::Slice([first, len]) => {
let groups = sort_by_s
.iter()
.map(|s| s.slice(first as i64, len as usize))
.collect::<Vec<_>>();
let sorted_idx =
groups[0].argsort_multiple(&groups[1..], &reverse).unwrap();
map_sorted_indices_to_group_slice(&sorted_idx, first)
}
};

(new_idx[0], new_idx)
})
.collect();

(GroupsProxy::Idx(groups), ordered_by_group_operation)
};

// if the rhs is already aggregated once,
// it is reordered by the groupby operation
// we must ensure that we are as well.
if ordered_by_group_operation {
let s = ac_in.aggregated();
ac_in.with_series(s.explode().unwrap(), false);
}

ac_in.with_groups(groups);
Ok(ac_in)
}
}

fn to_field(&self, input_schema: &Schema) -> PolarsResult<Field> {
Expand Down
26 changes: 26 additions & 0 deletions py-polars/polars/testing/_private.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,29 @@ def verify_series_and_expr_api(
else:
assert_series_equal(result_expr, expected)
assert_series_equal(result_series, expected)


def _to_rust_syntax(df: pli.DataFrame) -> str:
"""Utility to generate the syntax that creates a polars 'DataFrame' in Rust."""
syntax = "df![\n"

def format_s(s: pli.Series) -> str:
if s.null_count() == 0:
return str(s.to_list()).replace("'", '"')
else:
tmp = "["
for val in s:
if val is None:
tmp += "None, "
else:
if isinstance(val, str):
tmp += f'Some("{val}"), '
else:
tmp += f"Some({val}), "
tmp = tmp[:-2] + "]"
return tmp

for s in df:
syntax += f' "{s.name}" => {format_s(s)},\n'
syntax += "]"
return syntax
19 changes: 19 additions & 0 deletions py-polars/tests/unit/test_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,3 +322,22 @@ def test_sorted_join_query_5406() -> None:
pl.exclude(["Datetime_right", "Group_right"])
)
assert out["Value_right"].to_list() == [1, None, 2, 1, 2, None]


def test_sort_by_in_over_5499() -> None:
df = pl.DataFrame(
{
"group": [1, 1, 1, 2, 2, 2],
"idx": pl.arange(0, 6, eager=True),
"a": [1, 3, 2, 3, 1, 2],
}
)
assert df.select(
[
pl.col("idx").sort_by("a").over("group").alias("sorted_1"),
pl.col("idx").shift(1).sort_by("a").over("group").alias("sorted_2"),
]
).to_dict(False) == {
"sorted_1": [0, 2, 1, 4, 5, 3],
"sorted_2": [None, 1, 0, 3, 4, None],
}