Skip to content

Commit

Permalink
perf(rust, python): improve single argument elementwise expression performance
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Feb 25, 2023
1 parent 36b16ea commit 5f26668
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 64 deletions.
49 changes: 27 additions & 22 deletions polars/polars-core/src/chunked_array/list/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,31 +30,36 @@ impl ListChunked {
fld.coerce(DataType::List(Box::new(inner_dtype)))
}

/// Ignore the list indices and apply `func` to the inner type as `Series`.
pub fn apply_to_inner(
&self,
func: &dyn Fn(Series) -> PolarsResult<Series>,
) -> PolarsResult<ListChunked> {
let ca = self.rechunk();
let arr = ca.downcast_iter().next().unwrap();
let elements = Series::try_from(("", arr.values().clone())).unwrap();

let expected_len = elements.len();
let out: Series = func(elements)?;
if out.len() != expected_len {
return Err(PolarsError::ComputeError(
"The function should apply only elementwise Instead it has removed elements".into(),
));
}
let out = out.rechunk();
let values = out.chunks()[0].clone();

let inner_dtype = LargeListArray::default_datatype(out.dtype().to_arrow());
let arr = LargeListArray::new(
inner_dtype,
(*arr.offsets()).clone(),
values,
arr.validity().cloned(),
);
unsafe { Ok(ListChunked::from_chunks(self.name(), vec![Box::new(arr)])) }
let inner_dtype = self.inner_dtype().to_arrow();

let chunks = self.downcast_iter().map(|arr| {
let elements = unsafe { Series::try_from_arrow_unchecked("", vec![(*arr.values()).clone()], &inner_dtype).unwrap() } ;

let expected_len = elements.len();
let out: Series = func(elements)?;
if out.len() != expected_len {
return Err(PolarsError::ComputeError(
"The function should apply only elementwise Instead it has removed elements".into(),
));
}
let out = out.rechunk();
let values = out.chunks()[0].clone();

let inner_dtype = LargeListArray::default_datatype(out.dtype().to_arrow());
let arr = LargeListArray::new(
inner_dtype,
(*arr.offsets()).clone(),
values,
arr.validity().cloned(),
);
Ok(Box::new(arr) as ArrayRef)
}).collect::<PolarsResult<Vec<_>>>()?;

unsafe { Ok(ListChunked::from_chunks(self.name(), chunks)) }
}
}
56 changes: 17 additions & 39 deletions polars/polars-lazy/src/physical_plan/expressions/apply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ impl ApplyExpr {
}
}

/// evaluates and flattens `Option<Series>` to `Series`.
fn eval_and_flatten(&self, inputs: &mut [Series]) -> PolarsResult<Series> {
self.function.call_udf(inputs).map(|opt_out| {
opt_out.unwrap_or_else(|| {
Expand Down Expand Up @@ -143,38 +144,25 @@ impl ApplyExpr {
ca.rename(&name);
self.finish_apply_groups(ac, ca)
}

/// Apply elementwise e.g. ignore the group/list indices
fn apply_single_flattened<'a>(
&self,
mut ac: AggregationContext<'a>,
) -> PolarsResult<AggregationContext<'a>> {
// make sure the groups are updated because we are about to throw away
// the series' length information
let set_update_groups = match ac.update_groups {
UpdateGroups::WithSeriesLen => {
ac.groups();
true
let (s, aggregated) = match ac.agg_state() {
AggState::AggregatedList(s) => {
let ca = s.list().unwrap();
let out = ca.apply_to_inner(&|s| self.eval_and_flatten(&mut [s]))?;
(out.into_series(), true)
}
AggState::AggregatedFlat(s) => (self.eval_and_flatten(&mut [s.clone()])?, true),
AggState::NotAggregated(s) | AggState::Literal(s) => {
(self.eval_and_flatten(&mut [s.clone()])?, false)
}
UpdateGroups::WithSeriesLenOwned(_) => false,
UpdateGroups::No | UpdateGroups::WithGroupsLen => false,
};

if let UpdateGroups::WithSeriesLen = ac.update_groups {
ac.groups();
}

let input = ac.flat_naive().into_owned();
let input_len = input.len();
let s = self.eval_and_flatten(&mut [input])?;

check_map_output_len(input_len, s.len(), &self.expr)?;
ac.with_series(s, false, None)?;

if set_update_groups {
// The flat_naive orders by groups, so we must create new groups
// not by series length as we don't have an agg_list, but by original
// groups length
ac.update_groups = UpdateGroups::WithGroupsLen;
}
ac.with_series(s, aggregated, Some(&self.expr))?;
Ok(ac)
}
}
Expand Down Expand Up @@ -228,24 +216,14 @@ impl PhysicalExpr for ApplyExpr {
if self.inputs.len() == 1 {
let mut ac = self.inputs[0].evaluate_on_groups(df, groups, state)?;

match (state.has_overlapping_groups(), self.collect_groups) {
(_, ApplyOptions::ApplyList) => {
match self.collect_groups {
ApplyOptions::ApplyList => {
let s = self.eval_and_flatten(&mut [ac.aggregated()])?;
ac.with_series(s, true, Some(&self.expr))?;
Ok(ac)
}
// - series is aggregated flat/reduction: sum/min/mean etc.
// - apply options is apply_flat -> elementwise
// we can simply apply the function in a vectorized manner.
(_, ApplyOptions::ApplyFlat) if ac.is_aggregated_flat() => {
let s = ac.aggregated();
let s = self.eval_and_flatten(&mut [s])?;
ac.with_series(s, true, Some(&self.expr))?;
Ok(ac)
}
// overlapping groups always take this branch as explode/flat_naive bloats data size
(_, ApplyOptions::ApplyGroups) | (true, _) => self.apply_single_group_aware(ac),
(_, ApplyOptions::ApplyFlat) => self.apply_single_flattened(ac),
ApplyOptions::ApplyGroups => self.apply_single_group_aware(ac),
ApplyOptions::ApplyFlat => self.apply_single_flattened(ac),
}
} else {
let mut acs = self.prepare_multiple_inputs(df, groups, state)?;
Expand Down
4 changes: 1 addition & 3 deletions polars/polars-lazy/src/physical_plan/expressions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -186,9 +186,7 @@ impl<'a> AggregationContext<'a> {
pub(crate) fn is_aggregated(&self) -> bool {
!self.is_not_aggregated()
}
pub(crate) fn is_aggregated_flat(&self) -> bool {
matches!(self.state, AggState::AggregatedFlat(_))
}

pub(crate) fn is_literal(&self) -> bool {
matches!(self.state, AggState::Literal(_))
}
Expand Down

0 comments on commit 5f26668

Please sign in to comment.