Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Migrate polars-expr AggregationContext to use Column #19736

Merged
154 changes: 132 additions & 22 deletions crates/polars-core/src/frame/column/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use std::borrow::Cow;

use arrow::bitmap::MutableBitmap;
use arrow::trusted_len::TrustMyLength;
use num_traits::{Num, NumCast};
use polars_error::PolarsResult;
use polars_utils::index::check_bounds;
Expand All @@ -8,6 +10,7 @@ pub use scalar::ScalarColumn;

use self::gather::check_bounds_ca;
use self::partitioned::PartitionedColumn;
use self::series::SeriesColumn;
use crate::chunked_array::cast::CastOptions;
use crate::chunked_array::metadata::{MetadataFlags, MetadataTrait};
use crate::datatypes::ReshapeDimension;
Expand All @@ -20,6 +23,7 @@ mod arithmetic;
mod compare;
mod partitioned;
mod scalar;
mod series;

/// A column within a [`DataFrame`].
///
Expand All @@ -35,7 +39,7 @@ mod scalar;
#[cfg_attr(feature = "serde", serde(from = "Series"))]
#[cfg_attr(feature = "serde", serde(into = "_SerdeSeries"))]
pub enum Column {
Series(Series),
Series(SeriesColumn),
Partitioned(PartitionedColumn),
Scalar(ScalarColumn),
}
Expand All @@ -47,12 +51,13 @@ pub trait IntoColumn: Sized {

impl Column {
#[inline]
#[track_caller]
pub fn new<T, Phantom>(name: PlSmallStr, values: T) -> Self
where
Phantom: ?Sized,
Series: NamedFrom<T, Phantom>,
{
Self::Series(NamedFrom::new(name, values))
Self::Series(SeriesColumn::new(NamedFrom::new(name, values)))
}

#[inline]
Expand Down Expand Up @@ -95,7 +100,7 @@ impl Column {
PartitionedColumn::new_empty(PlSmallStr::EMPTY, DataType::Null),
)
.take_materialized_series();
*self = Column::Series(series);
*self = Column::Series(series.into());
let Column::Series(s) = self else {
unreachable!();
};
Expand All @@ -107,7 +112,7 @@ impl Column {
ScalarColumn::new_empty(PlSmallStr::EMPTY, DataType::Null),
)
.take_materialized_series();
*self = Column::Series(series);
*self = Column::Series(series.into());
let Column::Series(s) = self else {
unreachable!();
};
Expand All @@ -121,7 +126,7 @@ impl Column {
#[inline]
pub fn take_materialized_series(self) -> Series {
match self {
Column::Series(s) => s,
Column::Series(s) => s.take(),
Column::Partitioned(s) => s.take_materialized_series(),
Column::Scalar(s) => s.take_materialized_series(),
}
Expand Down Expand Up @@ -586,31 +591,102 @@ impl Column {
}
}

/// General implementation for aggregation where a non-missing scalar would map to itself.
#[inline(always)]
#[cfg(any(feature = "algorithm_group_by", feature = "bitwise"))]
fn agg_with_unit_scalar(
&self,
groups: &GroupsProxy,
series_agg: impl Fn(&Series, &GroupsProxy) -> Series,
) -> Column {
match self {
Column::Series(s) => series_agg(s, groups).into_column(),
// @partition-opt
Column::Partitioned(s) => series_agg(s.as_materialized_series(), groups).into_column(),
Column::Scalar(s) => {
if s.is_empty() {
return self.clone();
}

// We utilize the aggregation on Series to see:
// 1. the output datatype of the aggregation
// 2. whether this aggregation is even defined
let series_aggregation = series_agg(
&s.as_single_value_series(),
&GroupsProxy::Slice {
// @NOTE: this group is always valid since s is non-empty.
groups: vec![[0, 1]],
rolling: false,
},
);

// If the aggregation is not defined, just return all nulls.
if series_aggregation.has_nulls() {
return Self::new_scalar(
series_aggregation.name().clone(),
Scalar::new(series_aggregation.dtype().clone(), AnyValue::Null),
groups.len(),
);
}

let mut scalar_col = s.resize(groups.len());
// The aggregation might change the type (e.g. mean changes int -> float), so we do
// a cast here to the output type.
if series_aggregation.dtype() != s.dtype() {
scalar_col = scalar_col.cast(series_aggregation.dtype()).unwrap();
}

let Some(first_empty_idx) = groups.iter().position(|g| g.is_empty()) else {
// Fast path: no empty groups. keep the scalar intact.
return scalar_col.into_column();
};

// All empty groups produce a *missing* or `null` value.
let mut validity = MutableBitmap::with_capacity(groups.len());
validity.extend_constant(first_empty_idx, true);
// SAFETY: We trust the length of this iterator.
let iter = unsafe {
TrustMyLength::new(
groups.iter().skip(first_empty_idx).map(|g| !g.is_empty()),
groups.len() - first_empty_idx,
)
};
validity.extend_from_trusted_len_iter(iter);
let validity = validity.freeze();

let mut s = scalar_col.take_materialized_series().rechunk();
// SAFETY: We perform a compute_len afterwards.
let chunks = unsafe { s.chunks_mut() };
chunks[0].with_validity(Some(validity));
s.compute_len();

s.into_column()
},
}
}

/// # Safety
///
/// Does no bounds checks, groups must be correct.
#[cfg(feature = "algorithm_group_by")]
pub unsafe fn agg_min(&self, groups: &GroupsProxy) -> Self {
// @scalar-opt
unsafe { self.as_materialized_series().agg_min(groups) }.into()
self.agg_with_unit_scalar(groups, |s, g| unsafe { s.agg_min(g) })
}

/// # Safety
///
/// Does no bounds checks, groups must be correct.
#[cfg(feature = "algorithm_group_by")]
pub unsafe fn agg_max(&self, groups: &GroupsProxy) -> Self {
// @scalar-opt
unsafe { self.as_materialized_series().agg_max(groups) }.into()
self.agg_with_unit_scalar(groups, |s, g| unsafe { s.agg_max(g) })
}

/// # Safety
///
/// Does no bounds checks, groups must be correct.
#[cfg(feature = "algorithm_group_by")]
pub unsafe fn agg_mean(&self, groups: &GroupsProxy) -> Self {
// @scalar-opt
unsafe { self.as_materialized_series().agg_mean(groups) }.into()
self.agg_with_unit_scalar(groups, |s, g| unsafe { s.agg_mean(g) })
}

/// # Safety
Expand All @@ -627,17 +703,15 @@ impl Column {
/// Does no bounds checks, groups must be correct.
#[cfg(feature = "algorithm_group_by")]
pub unsafe fn agg_first(&self, groups: &GroupsProxy) -> Self {
// @scalar-opt
unsafe { self.as_materialized_series().agg_first(groups) }.into()
self.agg_with_unit_scalar(groups, |s, g| unsafe { s.agg_first(g) })
}

/// # Safety
///
/// Does no bounds checks, groups must be correct.
#[cfg(feature = "algorithm_group_by")]
pub unsafe fn agg_last(&self, groups: &GroupsProxy) -> Self {
// @scalar-opt
unsafe { self.as_materialized_series().agg_last(groups) }.into()
self.agg_with_unit_scalar(groups, |s, g| unsafe { s.agg_last(g) })
}

/// # Safety
Expand Down Expand Up @@ -672,8 +746,7 @@ impl Column {
/// Does no bounds checks, groups must be correct.
#[cfg(feature = "algorithm_group_by")]
pub unsafe fn agg_median(&self, groups: &GroupsProxy) -> Self {
// @scalar-opt
unsafe { self.as_materialized_series().agg_median(groups) }.into()
self.agg_with_unit_scalar(groups, |s, g| unsafe { s.agg_median(g) })
}

/// # Safety
Expand All @@ -689,7 +762,7 @@ impl Column {
///
/// Does no bounds checks, groups must be correct.
#[cfg(feature = "algorithm_group_by")]
pub(crate) unsafe fn agg_std(&self, groups: &GroupsProxy, ddof: u8) -> Self {
pub unsafe fn agg_std(&self, groups: &GroupsProxy, ddof: u8) -> Self {
// @scalar-opt
unsafe { self.as_materialized_series().agg_std(groups, ddof) }.into()
}
Expand All @@ -713,6 +786,30 @@ impl Column {
unsafe { self.as_materialized_series().agg_valid_count(groups) }.into()
}

/// # Safety
///
/// Does no bounds checks, groups must be correct.
#[cfg(feature = "bitwise")]
pub fn agg_and(&self, groups: &GroupsProxy) -> Self {
self.agg_with_unit_scalar(groups, |s, g| unsafe { s.agg_and(g) })
}
/// # Safety
///
/// Does no bounds checks, groups must be correct.
#[cfg(feature = "bitwise")]
pub fn agg_or(&self, groups: &GroupsProxy) -> Self {
self.agg_with_unit_scalar(groups, |s, g| unsafe { s.agg_or(g) })
}
/// # Safety
///
/// Does no bounds checks, groups must be correct.
#[cfg(feature = "bitwise")]
pub fn agg_xor(&self, groups: &GroupsProxy) -> Self {
// @partition-opt
// @scalar-opt
unsafe { self.as_materialized_series().agg_xor(groups) }.into()
}

    /// Construct a column of `size` nulls with the given name and dtype,
    /// represented as a scalar column.
    pub fn full_null(name: PlSmallStr, size: usize, dtype: &DataType) -> Self {
        Self::new_scalar(name, Scalar::new(dtype.clone(), AnyValue::Null), size)
    }
Expand Down Expand Up @@ -877,6 +974,13 @@ impl Column {
}
}

    /// Packs every element into a list.
    pub fn as_list(&self) -> ListChunked {
        // @scalar-opt
        // @partition-opt
        // Scalar/partitioned variants are fully materialized first; the @-tags
        // above mark planned specializations.
        self.as_materialized_series().as_list()
    }

pub fn is_sorted_flag(&self) -> IsSorted {
// @scalar-opt
self.as_materialized_series().is_sorted_flag()
Expand Down Expand Up @@ -1105,19 +1209,25 @@ impl Column {

pub fn try_add_owned(self, other: Self) -> PolarsResult<Self> {
match (self, other) {
(Column::Series(lhs), Column::Series(rhs)) => lhs.try_add_owned(rhs).map(Column::from),
(Column::Series(lhs), Column::Series(rhs)) => {
lhs.take().try_add_owned(rhs.take()).map(Column::from)
},
(lhs, rhs) => lhs + rhs,
}
}
pub fn try_sub_owned(self, other: Self) -> PolarsResult<Self> {
match (self, other) {
(Column::Series(lhs), Column::Series(rhs)) => lhs.try_sub_owned(rhs).map(Column::from),
(Column::Series(lhs), Column::Series(rhs)) => {
lhs.take().try_sub_owned(rhs.take()).map(Column::from)
},
(lhs, rhs) => lhs - rhs,
}
}
pub fn try_mul_owned(self, other: Self) -> PolarsResult<Self> {
match (self, other) {
(Column::Series(lhs), Column::Series(rhs)) => lhs.try_mul_owned(rhs).map(Column::from),
(Column::Series(lhs), Column::Series(rhs)) => {
lhs.take().try_mul_owned(rhs.take()).map(Column::from)
},
(lhs, rhs) => lhs * rhs,
}
}
Expand Down Expand Up @@ -1443,7 +1553,7 @@ impl From<Series> for Column {
return Self::Scalar(ScalarColumn::unit_scalar_from_series(series));
}

Self::Series(series)
Self::Series(SeriesColumn::new(series))
}
}

Expand Down
2 changes: 1 addition & 1 deletion crates/polars-core/src/frame/column/partitioned.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ impl PartitionedColumn {

fn _to_series(name: PlSmallStr, values: &Series, ends: &[IdxSize]) -> Series {
let dtype = values.dtype();
let mut column = Column::Series(Series::new_empty(name, dtype));
let mut column = Column::Series(Series::new_empty(name, dtype).into());

let mut prev_offset = 0;
for (i, &offset) in ends.iter().enumerate() {
Expand Down
5 changes: 5 additions & 0 deletions crates/polars-core/src/frame/column/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,11 @@ impl ScalarColumn {
self.scalar.update(AnyValue::Null);
self
}

    /// Replace the scalar value via `map_scalar` and invalidate the cached
    /// materialized series so it cannot go stale.
    pub fn map_scalar(&mut self, map_scalar: impl Fn(Scalar) -> Scalar) {
        // `mem::take` moves the scalar out without cloning (relies on `Scalar: Default`).
        self.scalar = map_scalar(std::mem::take(&mut self.scalar));
        // Clear the cache — presumably `materialized` is an Option/OnceLock-style
        // cell whose `take()` empties it; TODO confirm field type.
        self.materialized.take();
    }
}

impl IntoColumn for ScalarColumn {
Expand Down
71 changes: 71 additions & 0 deletions crates/polars-core/src/frame/column/series.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
use std::ops::{Deref, DerefMut};

use super::Series;

/// A very thin wrapper around [`Series`] that represents a [`Column`]ized version of [`Series`].
///
/// At the moment this just conditionally tracks where it was created so that materialization
/// problems can be tracked down.
#[derive(Debug, Clone)]
pub struct SeriesColumn {
    // The wrapped series; exposed through `Deref`/`DerefMut`.
    inner: Series,

    // Debug builds only: backtrace captured at construction when the env var
    // POLARS_TRACK_SERIES_MATERIALIZATION=1, used to locate unwanted
    // materializations. Arc keeps Clone cheap.
    #[cfg(debug_assertions)]
    materialized_at: Option<std::sync::Arc<std::backtrace::Backtrace>>,
}

impl SeriesColumn {
    /// Wrap a [`Series`].
    ///
    /// In debug builds, when the env var `POLARS_TRACK_SERIES_MATERIALIZATION`
    /// is set to `1`, a backtrace of the construction site is captured so stray
    /// materializations can be traced back to their origin.
    #[track_caller]
    pub fn new(series: Series) -> Self {
        #[cfg(debug_assertions)]
        let materialized_at = {
            let tracking_enabled =
                std::env::var("POLARS_TRACK_SERIES_MATERIALIZATION").as_deref() == Ok("1");
            tracking_enabled
                .then(|| std::sync::Arc::new(std::backtrace::Backtrace::force_capture()))
        };

        Self {
            inner: series,

            #[cfg(debug_assertions)]
            materialized_at,
        }
    }

    /// Backtrace of where this column was created, if tracking was enabled.
    /// Always `None` in release builds.
    pub fn materialized_at(&self) -> Option<&std::backtrace::Backtrace> {
        #[cfg(debug_assertions)]
        {
            return self.materialized_at.as_deref();
        }

        #[cfg(not(debug_assertions))]
        None
    }

    /// Unwrap into the inner [`Series`].
    pub fn take(self) -> Series {
        self.inner
    }
}

impl From<Series> for SeriesColumn {
#[track_caller]
#[inline(always)]
fn from(value: Series) -> Self {
Self::new(value)
}
}

impl Deref for SeriesColumn {
    type Target = Series;

    /// Transparent, read-only access to the wrapped [`Series`].
    fn deref(&self) -> &Self::Target {
        &self.inner
    }
}

impl DerefMut for SeriesColumn {
    /// Mutable access to the wrapped [`Series`]; only construction sites are
    /// tracked, so mutation through this path records nothing.
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.inner
    }
}
Loading
Loading