Skip to content

Commit

Permalink
Implement date mean/median
Browse files Browse the repository at this point in the history
  • Loading branch information
mcrumiller committed Jun 6, 2024
1 parent 4d35be2 commit df32b05
Show file tree
Hide file tree
Showing 15 changed files with 256 additions and 188 deletions.
122 changes: 78 additions & 44 deletions crates/polars-core/src/frame/group_by/aggregations/dispatch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,58 @@ impl Series {
}
}

#[doc(hidden)]
/// Aggregate the mean of each group.
///
/// Numeric and boolean groups yield float means; temporal dtypes
/// (Datetime/Duration/Time) are averaged over their physical `i64`
/// representation and cast back to the logical dtype, while `Date` is
/// promoted to `Datetime(ms)` so sub-day precision of the mean is kept.
/// Unsupported dtypes produce a full-null Series.
///
/// # Safety
/// Caller must guarantee the indices in `groups` are in bounds for `self`.
pub unsafe fn agg_mean(&self, groups: &GroupsProxy) -> Series {
    // Prevent a rechunk for every individual group.
    let s = if groups.len() > 1 {
        self.rechunk()
    } else {
        self.clone()
    };

    use DataType::*;
    match s.dtype() {
        // Booleans are averaged as 0.0/1.0 floats.
        Boolean => s.cast(&Float64).unwrap().agg_mean(groups),
        Float32 => SeriesWrap(s.f32().unwrap().clone()).agg_mean(groups),
        Float64 => SeriesWrap(s.f64().unwrap().clone()).agg_mean(groups),
        dt if dt.is_numeric() => apply_method_physical_integer!(s, agg_mean, groups),
        // NOTE: the temporal arms below deliberately aggregate `s` (the
        // rechunked clone), not `self`, so the single-rechunk optimization
        // above also applies to temporal dtypes.
        #[cfg(feature = "dtype-datetime")]
        dt @ Datetime(_, _) => s
            .to_physical_repr()
            .agg_mean(groups)
            .cast(&Int64)
            .unwrap()
            .cast(dt)
            .unwrap(),
        #[cfg(feature = "dtype-duration")]
        dt @ Duration(_) => s
            .to_physical_repr()
            .agg_mean(groups)
            .cast(&Int64)
            .unwrap()
            .cast(dt)
            .unwrap(),
        #[cfg(feature = "dtype-time")]
        Time => s
            .to_physical_repr()
            .agg_mean(groups)
            .cast(&Int64)
            .unwrap()
            .cast(&Time)
            .unwrap(),
        // Date: mean the physical day counts as f64, scale to milliseconds,
        // and return a Datetime(ms) so fractional days are not truncated.
        #[cfg(feature = "dtype-date")]
        Date => (s
            .to_physical_repr()
            .agg_mean(groups)
            .cast(&Float64)
            .unwrap()
            * (MS_IN_DAY as f64))
            .cast(&Datetime(TimeUnit::Milliseconds, None))
            .unwrap(),
        _ => Series::full_null("", groups.len(), s.dtype()),
    }
}

#[doc(hidden)]
pub unsafe fn agg_median(&self, groups: &GroupsProxy) -> Series {
// Prevent a rechunk for every individual group.
Expand All @@ -143,21 +195,38 @@ impl Series {
Float64 => SeriesWrap(s.f64().unwrap().clone()).agg_median(groups),
dt if dt.is_numeric() => apply_method_physical_integer!(s, agg_median, groups),
#[cfg(feature = "dtype-datetime")]
dt @ (Datetime(_, _) | Duration(_) | Time) => s
dt @ Datetime(_, _) => self
.to_physical_repr()
.agg_median(groups)
.cast(&Int64)
.unwrap()
.cast(dt)
.unwrap(),
dt @ Date => {
let ca = s.to_physical_repr();
let physical_type = ca.dtype();
let s = apply_method_physical_integer!(ca, agg_median, groups);
// back to physical and then
// back to logical type
s.cast(physical_type).unwrap().cast(dt).unwrap()
},
#[cfg(feature = "dtype-duration")]
dt @ Duration(_) => self
.to_physical_repr()
.agg_median(groups)
.cast(&Int64)
.unwrap()
.cast(dt)
.unwrap(),
#[cfg(feature = "dtype-time")]
Time => self
.to_physical_repr()
.agg_median(groups)
.cast(&Int64)
.unwrap()
.cast(&Time)
.unwrap(),
#[cfg(feature = "dtype-date")]
Date => (self
.to_physical_repr()
.agg_median(groups)
.cast(&Float64)
.unwrap()
* (MS_IN_DAY as f64))
.cast(&Datetime(TimeUnit::Milliseconds, None))
.unwrap(),
_ => Series::full_null("", groups.len(), s.dtype()),
}
}
Expand Down Expand Up @@ -197,41 +266,6 @@ impl Series {
}
}

#[doc(hidden)]
// Aggregates the mean of each group, dispatching on the Series dtype.
// SAFETY: caller must guarantee the indices in `groups` are in bounds for `self`.
pub unsafe fn agg_mean(&self, groups: &GroupsProxy) -> Series {
// Prevent a rechunk for every individual group.
let s = if groups.len() > 1 {
self.rechunk()
} else {
self.clone()
};

use DataType::*;
match s.dtype() {
// Booleans are averaged as 0.0/1.0 floats.
Boolean => s.cast(&Float64).unwrap().agg_mean(groups),
Float32 => SeriesWrap(s.f32().unwrap().clone()).agg_mean(groups),
Float64 => SeriesWrap(s.f64().unwrap().clone()).agg_mean(groups),
// Integer dtypes dispatch to the physical integer implementation.
dt if dt.is_numeric() => apply_method_physical_integer!(s, agg_mean, groups),
// Temporal dtypes: mean over the physical i64 values, then cast back
// to the original logical dtype.
#[cfg(feature = "dtype-datetime")]
dt @ (Datetime(_, _) | Duration(_) | Time) => s
.to_physical_repr()
.agg_mean(groups)
.cast(&Int64)
.unwrap()
.cast(dt)
.unwrap(),
// NOTE(review): casting the float mean back to the integer physical
// type appears to truncate to whole days, losing sub-day precision of
// the mean — confirm whether that is intended.
dt @ Date => {
let ca = s.to_physical_repr();
let physical_type = ca.dtype();
let s = apply_method_physical_integer!(ca, agg_mean, groups);
// back to physical and then
// back to logical type
s.cast(physical_type).unwrap().cast(dt).unwrap()
},
// Unsupported dtypes yield a full-null column of the same dtype.
_ => Series::full_null("", groups.len(), s.dtype()),
}
}

#[doc(hidden)]
pub unsafe fn agg_last(&self, groups: &GroupsProxy) -> Series {
// Prevent a rechunk for every individual group.
Expand Down
13 changes: 8 additions & 5 deletions crates/polars-core/src/series/implementations/date.rs
Original file line number Diff line number Diff line change
Expand Up @@ -317,11 +317,14 @@ impl SeriesTrait for SeriesWrap<DateChunked> {
}

fn median_reduce(&self) -> PolarsResult<Scalar> {
let av = AnyValue::from(self.median().map(|v| v as i64))
.cast(self.dtype())
.into_static()
.unwrap();
Ok(Scalar::new(self.dtype().clone(), av))
let av: AnyValue = self
.median()
.map(|v| (v * (MS_IN_DAY as f64)) as i64)
.into();
Ok(Scalar::new(
DataType::Datetime(TimeUnit::Milliseconds, None),
av,
))
}

fn clone_inner(&self) -> Arc<dyn SeriesTrait> {
Expand Down
21 changes: 20 additions & 1 deletion crates/polars-core/src/series/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -889,7 +889,26 @@ impl Series {
let val = self.mean();
Scalar::new(DataType::Float64, val.into())
},
dt if dt.is_temporal() => {
#[cfg(feature = "dtype-date")]
DataType::Date => {
let val = self.mean().map(|v| (v * MS_IN_DAY as f64) as i64);
let av: AnyValue = val.into();
Scalar::new(DataType::Datetime(TimeUnit::Milliseconds, None), av)
},
#[cfg(feature = "dtype-datetime")]
dt @ DataType::Datetime(_, _) => {
let val = self.mean().map(|v| v as i64);
let av: AnyValue = val.into();
Scalar::new(dt.clone(), av)
},
#[cfg(feature = "dtype-duration")]
dt @ DataType::Duration(_) => {
let val = self.mean().map(|v| v as i64);
let av: AnyValue = val.into();
Scalar::new(dt.clone(), av)
},
#[cfg(feature = "dtype-time")]
dt @ DataType::Time => {
let val = self.mean().map(|v| v as i64);
let av: AnyValue = val.into();
Scalar::new(dt.clone(), av)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ where
#[cfg(feature = "dtype-categorical")]
if matches!(
logical_dtype,
DataType::Categorical(_, _) | DataType::Enum(_, _)
DataType::Categorical(_, _) | DataType::Enum(_, _) | DataType::Date
) {
return (
logical_dtype.clone(),
Expand Down
10 changes: 8 additions & 2 deletions crates/polars-plan/src/logical_plan/aexpr/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,13 +135,19 @@ impl AExpr {
Median(expr) => {
let mut field =
arena.get(*expr).to_field(schema, Context::Default, arena)?;
float_type(&mut field);
match field.dtype {
Date => field.coerce(Datetime(TimeUnit::Milliseconds, None)),
_ => float_type(&mut field),
}
Ok(field)
},
Mean(expr) => {
let mut field =
arena.get(*expr).to_field(schema, Context::Default, arena)?;
float_type(&mut field);
match field.dtype {
Date => field.coerce(Datetime(TimeUnit::Milliseconds, None)),
_ => float_type(&mut field),
}
Ok(field)
},
Implode(expr) => {
Expand Down
22 changes: 2 additions & 20 deletions crates/polars-plan/src/logical_plan/conversion/dsl_to_ir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -471,16 +471,7 @@ pub fn to_alp_impl(
&input_schema,
),
StatsFunction::Mean => stats_helper(
|dt| {
dt.is_numeric()
|| matches!(
dt,
DataType::Boolean
| DataType::Duration(_)
| DataType::Datetime(_, _)
| DataType::Time
)
},
|dt| dt.is_numeric() || dt.is_temporal() || dt == &DataType::Boolean,
|name| col(name).mean(),
&input_schema,
),
Expand All @@ -500,16 +491,7 @@ pub fn to_alp_impl(
stats_helper(|dt| dt.is_ord(), |name| col(name).max(), &input_schema)
},
StatsFunction::Median => stats_helper(
|dt| {
dt.is_numeric()
|| matches!(
dt,
DataType::Boolean
| DataType::Duration(_)
| DataType::Datetime(_, _)
| DataType::Time
)
},
|dt| dt.is_numeric() || dt.is_temporal() || dt == &DataType::Boolean,
|name| col(name).median(),
&input_schema,
),
Expand Down
66 changes: 33 additions & 33 deletions py-polars/polars/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -4411,21 +4411,21 @@ def describe(
>>> df.describe()
shape: (9, 7)
┌────────────┬──────────┬──────────┬──────────┬──────┬────────────┬──────────┐
│ statistic ┆ float ┆ int ┆ bool ┆ str ┆ date ┆ time │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ f64 ┆ f64 ┆ f64 ┆ str ┆ str ┆ str │
╞════════════╪══════════╪══════════╪══════════╪══════╪════════════╪══════════╡
│ count ┆ 3.0 ┆ 2.0 ┆ 3.0 ┆ 3 ┆ 3 ┆ 3 │
│ null_count ┆ 0.0 ┆ 1.0 ┆ 0.0 ┆ 0 ┆ 0 ┆ 0 │
│ mean ┆ 2.266667 ┆ 45.0 ┆ 0.666667 ┆ null ┆ 2021-07-02 ┆ 16:07:10 │
│ std ┆ 1.101514 ┆ 7.071068 ┆ null ┆ null ┆ null ┆ null │
│ min ┆ 1.0 ┆ 40.0 ┆ 0.0 ┆ xx ┆ 2020-01-01 ┆ 10:20:30 │
│ 25% ┆ 2.8 ┆ 40.0 ┆ null ┆ null ┆ 2021-07-05 ┆ 14:45:50 │
│ 50% ┆ 2.8 ┆ 50.0 ┆ null ┆ null ┆ 2021-07-05 ┆ 14:45:50 │
│ 75% ┆ 3.0 ┆ 50.0 ┆ null ┆ null ┆ 2022-12-31 ┆ 23:15:10 │
│ max ┆ 3.0 ┆ 50.0 ┆ 1.0 ┆ zz ┆ 2022-12-31 ┆ 23:15:10 │
└────────────┴──────────┴──────────┴──────────┴──────┴────────────┴──────────┘
┌────────────┬──────────┬──────────┬──────────┬──────┬─────────────────────┬──────────┐
│ statistic ┆ float ┆ int ┆ bool ┆ str ┆ date ┆ time │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ f64 ┆ f64 ┆ f64 ┆ str ┆ str ┆ str │
╞════════════╪══════════╪══════════╪══════════╪══════╪═════════════════════╪══════════╡
│ count ┆ 3.0 ┆ 2.0 ┆ 3.0 ┆ 3 ┆ 3 ┆ 3 │
│ null_count ┆ 0.0 ┆ 1.0 ┆ 0.0 ┆ 0 ┆ 0 ┆ 0 │
│ mean ┆ 2.266667 ┆ 45.0 ┆ 0.666667 ┆ null ┆ 2021-07-02 16:00:00 ┆ 16:07:10 │
│ std ┆ 1.101514 ┆ 7.071068 ┆ null ┆ null ┆ null ┆ null │
│ min ┆ 1.0 ┆ 40.0 ┆ 0.0 ┆ xx ┆ 2020-01-01 ┆ 10:20:30 │
│ 25% ┆ 2.8 ┆ 40.0 ┆ null ┆ null ┆ 2021-07-05 ┆ 14:45:50 │
│ 50% ┆ 2.8 ┆ 50.0 ┆ null ┆ null ┆ 2021-07-05 ┆ 14:45:50 │
│ 75% ┆ 3.0 ┆ 50.0 ┆ null ┆ null ┆ 2022-12-31 ┆ 23:15:10 │
│ max ┆ 3.0 ┆ 50.0 ┆ 1.0 ┆ zz ┆ 2022-12-31 ┆ 23:15:10 │
└────────────┴──────────┴──────────┴──────────┴──────┴─────────────────────┴──────────┘
Customize which percentiles are displayed, applying linear interpolation:
Expand All @@ -4435,24 +4435,24 @@ def describe(
... interpolation="linear",
... )
shape: (11, 7)
┌────────────┬──────────┬──────────┬──────────┬──────┬────────────┬──────────┐
│ statistic ┆ float ┆ int ┆ bool ┆ str ┆ date ┆ time │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ f64 ┆ f64 ┆ f64 ┆ str ┆ str ┆ str │
╞════════════╪══════════╪══════════╪══════════╪══════╪════════════╪══════════╡
│ count ┆ 3.0 ┆ 2.0 ┆ 3.0 ┆ 3 ┆ 3 ┆ 3 │
│ null_count ┆ 0.0 ┆ 1.0 ┆ 0.0 ┆ 0 ┆ 0 ┆ 0 │
│ mean ┆ 2.266667 ┆ 45.0 ┆ 0.666667 ┆ null ┆ 2021-07-02 ┆ 16:07:10 │
│ std ┆ 1.101514 ┆ 7.071068 ┆ null ┆ null ┆ null ┆ null │
│ min ┆ 1.0 ┆ 40.0 ┆ 0.0 ┆ xx ┆ 2020-01-01 ┆ 10:20:30 │
│ 10% ┆ 1.36 ┆ 41.0 ┆ null ┆ null ┆ 2020-04-20 ┆ 11:13:34 │
│ 30% ┆ 2.08 ┆ 43.0 ┆ null ┆ null ┆ 2020-11-26 ┆ 12:59:42 │
│ 50% ┆ 2.8 ┆ 45.0 ┆ null ┆ null ┆ 2021-07-05 ┆ 14:45:50 │
│ 70% ┆ 2.88 ┆ 47.0 ┆ null ┆ null ┆ 2022-02-07 ┆ 18:09:34 │
│ 90% ┆ 2.96 ┆ 49.0 ┆ null ┆ null ┆ 2022-09-13 ┆ 21:33:18 │
│ max ┆ 3.0 ┆ 50.0 ┆ 1.0 ┆ zz ┆ 2022-12-31 ┆ 23:15:10 │
└────────────┴──────────┴──────────┴──────────┴──────┴────────────┴──────────┘
"""
┌────────────┬──────────┬──────────┬──────────┬──────┬─────────────────────┬──────────┐
│ statistic ┆ float ┆ int ┆ bool ┆ str ┆ date ┆ time │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ f64 ┆ f64 ┆ f64 ┆ str ┆ str ┆ str │
╞════════════╪══════════╪══════════╪══════════╪══════╪═════════════════════╪══════════╡
│ count ┆ 3.0 ┆ 2.0 ┆ 3.0 ┆ 3 ┆ 3 ┆ 3 │
│ null_count ┆ 0.0 ┆ 1.0 ┆ 0.0 ┆ 0 ┆ 0 ┆ 0 │
│ mean ┆ 2.266667 ┆ 45.0 ┆ 0.666667 ┆ null ┆ 2021-07-02 16:00:00 ┆ 16:07:10 │
│ std ┆ 1.101514 ┆ 7.071068 ┆ null ┆ null ┆ null ┆ null │
│ min ┆ 1.0 ┆ 40.0 ┆ 0.0 ┆ xx ┆ 2020-01-01 ┆ 10:20:30 │
│ 10% ┆ 1.36 ┆ 41.0 ┆ null ┆ null ┆ 2020-04-20 ┆ 11:13:34 │
│ 30% ┆ 2.08 ┆ 43.0 ┆ null ┆ null ┆ 2020-11-26 ┆ 12:59:42 │
│ 50% ┆ 2.8 ┆ 45.0 ┆ null ┆ null ┆ 2021-07-05 ┆ 14:45:50 │
│ 70% ┆ 2.88 ┆ 47.0 ┆ null ┆ null ┆ 2022-02-07 ┆ 18:09:34 │
│ 90% ┆ 2.96 ┆ 49.0 ┆ null ┆ null ┆ 2022-09-13 ┆ 21:33:18 │
│ max ┆ 3.0 ┆ 50.0 ┆ 1.0 ┆ zz ┆ 2022-12-31 ┆ 23:15:10 │
└────────────┴──────────┴──────────┴──────────┴──────┴─────────────────────┴──────────┘
""" # noqa: W505
if not self.columns:
msg = "cannot describe a DataFrame that has no columns"
raise TypeError(msg)
Expand Down
Loading

0 comments on commit df32b05

Please sign in to comment.