perf(rust, python): improve single argument elementwise expression pe… #7180

Merged 1 commit on Feb 25, 2023
49 changes: 27 additions & 22 deletions polars/polars-core/src/chunked_array/list/mod.rs
@@ -30,31 +30,36 @@ impl ListChunked {
         fld.coerce(DataType::List(Box::new(inner_dtype)))
     }
 
+    /// Ignore the list indices and apply `func` to the inner type as `Series`.
     pub fn apply_to_inner(
         &self,
         func: &dyn Fn(Series) -> PolarsResult<Series>,
     ) -> PolarsResult<ListChunked> {
-        let ca = self.rechunk();
-        let arr = ca.downcast_iter().next().unwrap();
-        let elements = Series::try_from(("", arr.values().clone())).unwrap();
-
-        let expected_len = elements.len();
-        let out: Series = func(elements)?;
-        if out.len() != expected_len {
-            return Err(PolarsError::ComputeError(
-                "The function should apply only elementwise Instead it has removed elements".into(),
-            ));
-        }
-        let out = out.rechunk();
-        let values = out.chunks()[0].clone();
-
-        let inner_dtype = LargeListArray::default_datatype(out.dtype().to_arrow());
-        let arr = LargeListArray::new(
-            inner_dtype,
-            (*arr.offsets()).clone(),
-            values,
-            arr.validity().cloned(),
-        );
-        unsafe { Ok(ListChunked::from_chunks(self.name(), vec![Box::new(arr)])) }
+        let inner_dtype = self.inner_dtype().to_arrow();
+
+        let chunks = self.downcast_iter().map(|arr| {
+            let elements = unsafe { Series::try_from_arrow_unchecked("", vec![(*arr.values()).clone()], &inner_dtype).unwrap() };
+
+            let expected_len = elements.len();
+            let out: Series = func(elements)?;
+            if out.len() != expected_len {
+                return Err(PolarsError::ComputeError(
+                    "The function should apply only elementwise Instead it has removed elements".into(),
+                ));
+            }
+            let out = out.rechunk();
+            let values = out.chunks()[0].clone();
+
+            let inner_dtype = LargeListArray::default_datatype(out.dtype().to_arrow());
+            let arr = LargeListArray::new(
+                inner_dtype,
+                (*arr.offsets()).clone(),
+                values,
+                arr.validity().cloned(),
+            );
+            Ok(Box::new(arr) as ArrayRef)
+        }).collect::<PolarsResult<Vec<_>>>()?;
+
+        unsafe { Ok(ListChunked::from_chunks(self.name(), chunks)) }
     }
 }
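With this change, `apply_to_inner` maps the supplied closure over the inner values of every chunk instead of rechunking the whole array into a single chunk first. A usage sketch (not part of the PR), assuming a polars release from around this change where `ListChunked::apply_to_inner` is public; the helper lives in polars-core internals and may move in later versions:

use polars::prelude::*;

fn cast_inner_to_float() -> PolarsResult<()> {
    // Build a small list Series: [[1, 2, 3], [4, 5]].
    let a = Series::new("", &[1i32, 2, 3]);
    let b = Series::new("", &[4i32, 5]);
    let s = Series::new("nums", &[a, b]);
    let ca: &ListChunked = s.list()?;

    // The closure sees each chunk's inner values as one flat `Series` and must
    // not change their length; here we only cast the inner dtype.
    let out = ca.apply_to_inner(&|inner: Series| inner.cast(&DataType::Float64))?;
    assert_eq!(out.len(), ca.len());
    Ok(())
}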
56 changes: 17 additions & 39 deletions polars/polars-lazy/src/physical_plan/expressions/apply.rs
@@ -84,6 +84,7 @@ impl ApplyExpr {
         }
     }
 
+    /// evaluates and flattens `Option<Series>` to `Series`.
     fn eval_and_flatten(&self, inputs: &mut [Series]) -> PolarsResult<Series> {
         self.function.call_udf(inputs).map(|opt_out| {
             opt_out.unwrap_or_else(|| {
@@ -143,38 +144,25 @@ impl ApplyExpr {
         ca.rename(&name);
         self.finish_apply_groups(ac, ca)
     }
 
+    /// Apply elementwise e.g. ignore the group/list indices
     fn apply_single_flattened<'a>(
         &self,
         mut ac: AggregationContext<'a>,
     ) -> PolarsResult<AggregationContext<'a>> {
-        // make sure the groups are updated because we are about to throw away
-        // the series' length information
-        let set_update_groups = match ac.update_groups {
-            UpdateGroups::WithSeriesLen => {
-                ac.groups();
-                true
-            }
-            UpdateGroups::WithSeriesLenOwned(_) => false,
-            UpdateGroups::No | UpdateGroups::WithGroupsLen => false,
-        };
-
-        if let UpdateGroups::WithSeriesLen = ac.update_groups {
-            ac.groups();
-        }
-
-        let input = ac.flat_naive().into_owned();
-        let input_len = input.len();
-        let s = self.eval_and_flatten(&mut [input])?;
-
-        check_map_output_len(input_len, s.len(), &self.expr)?;
-        ac.with_series(s, false, None)?;
-
-        if set_update_groups {
-            // The flat_naive orders by groups, so we must create new groups
-            // not by series length as we don't have an agg_list, but by original
-            // groups length
-            ac.update_groups = UpdateGroups::WithGroupsLen;
-        }
+        let (s, aggregated) = match ac.agg_state() {
+            AggState::AggregatedList(s) => {
+                let ca = s.list().unwrap();
+                let out = ca.apply_to_inner(&|s| self.eval_and_flatten(&mut [s]))?;
+                (out.into_series(), true)
+            }
+            AggState::AggregatedFlat(s) => (self.eval_and_flatten(&mut [s.clone()])?, true),
+            AggState::NotAggregated(s) | AggState::Literal(s) => {
+                (self.eval_and_flatten(&mut [s.clone()])?, false)
+            }
+        };
+
+        ac.with_series(s, aggregated, Some(&self.expr))?;
         Ok(ac)
     }
 }
@@ -228,24 +216,14 @@ impl PhysicalExpr for ApplyExpr {
         if self.inputs.len() == 1 {
             let mut ac = self.inputs[0].evaluate_on_groups(df, groups, state)?;
 
-            match (state.has_overlapping_groups(), self.collect_groups) {
-                (_, ApplyOptions::ApplyList) => {
+            match self.collect_groups {
+                ApplyOptions::ApplyList => {
                     let s = self.eval_and_flatten(&mut [ac.aggregated()])?;
                     ac.with_series(s, true, Some(&self.expr))?;
                     Ok(ac)
                 }
-                // - series is aggregated flat/reduction: sum/min/mean etc.
-                // - apply options is apply_flat -> elementwise
-                // we can simply apply the function in a vectorized manner.
-                (_, ApplyOptions::ApplyFlat) if ac.is_aggregated_flat() => {
-                    let s = ac.aggregated();
-                    let s = self.eval_and_flatten(&mut [s])?;
-                    ac.with_series(s, true, Some(&self.expr))?;
-                    Ok(ac)
-                }
-                // overlapping groups always take this branch as explode/flat_naive bloats data size
-                (_, ApplyOptions::ApplyGroups) | (true, _) => self.apply_single_group_aware(ac),
-                (_, ApplyOptions::ApplyFlat) => self.apply_single_flattened(ac),
+                ApplyOptions::ApplyGroups => self.apply_single_group_aware(ac),
+                ApplyOptions::ApplyFlat => self.apply_single_flattened(ac),
             }
         } else {
             let mut acs = self.prepare_multiple_inputs(df, groups, state)?;
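For a sense of what exercises this path: a single-argument elementwise expression evaluated in a groupby context now runs once over the aggregated list's inner values (via `apply_to_inner`) instead of being handled group by group. A rough sketch of such a query (not part of the PR), assuming the polars crate of this era with the `lazy` and `abs` features enabled; `groupby` was renamed `group_by` in later releases:

use polars::prelude::*;

fn elementwise_in_groupby() -> PolarsResult<()> {
    let df = df![
        "group" => &["a", "a", "b"],
        "value" => &[-1i64, 2, -3],
    ]?;

    // `abs` is a single-input elementwise expression, so inside the groupby
    // context it should take the flattened fast path changed above.
    let out = df
        .lazy()
        .groupby([col("group")])
        .agg([col("value").abs()])
        .collect()?;
    println!("{out}");
    Ok(())
}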
4 changes: 1 addition & 3 deletions polars/polars-lazy/src/physical_plan/expressions/mod.rs
@@ -186,9 +186,7 @@ impl<'a> AggregationContext<'a> {
     pub(crate) fn is_aggregated(&self) -> bool {
         !self.is_not_aggregated()
     }
-    pub(crate) fn is_aggregated_flat(&self) -> bool {
-        matches!(self.state, AggState::AggregatedFlat(_))
-    }
+
     pub(crate) fn is_literal(&self) -> bool {
         matches!(self.state, AggState::Literal(_))
     }