Skip to content

Commit

Permalink
feat: raise if both closed and by are passed to rolling_* aggre…
Browse files Browse the repository at this point in the history
…gations (#15108)
  • Loading branch information
MarcoGorelli authored Mar 18, 2024
1 parent c0fada8 commit e21bd27
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 53 deletions.
16 changes: 8 additions & 8 deletions crates/polars-plan/src/dsl/function_expr/rolling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ fn convert<'a>(
by: Some(by_values),
tu: Some(tu),
tz: tz.as_ref(),
closed_window: options.closed_window,
closed_window: options.closed_window.or(Some(ClosedWindow::Right)),
fn_params: options.fn_params.clone(),
};

Expand All @@ -130,39 +130,39 @@ fn convert<'a>(
}

pub(super) fn rolling_min(s: &Series, options: RollingOptions) -> PolarsResult<Series> {
s.rolling_min(options.clone().into())
s.rolling_min(options.clone().try_into()?)
}

pub(super) fn rolling_min_by(s: &[Series], options: RollingOptions) -> PolarsResult<Series> {
convert(|options| s[0].rolling_min(options), s, "rolling_min")(options)
}

pub(super) fn rolling_max(s: &Series, options: RollingOptions) -> PolarsResult<Series> {
s.rolling_max(options.clone().into())
s.rolling_max(options.clone().try_into()?)
}

pub(super) fn rolling_max_by(s: &[Series], options: RollingOptions) -> PolarsResult<Series> {
convert(|options| s[0].rolling_max(options), s, "rolling_max")(options)
}

pub(super) fn rolling_mean(s: &Series, options: RollingOptions) -> PolarsResult<Series> {
s.rolling_mean(options.clone().into())
s.rolling_mean(options.clone().try_into()?)
}

pub(super) fn rolling_mean_by(s: &[Series], options: RollingOptions) -> PolarsResult<Series> {
convert(|options| s[0].rolling_mean(options), s, "rolling_mean")(options)
}

pub(super) fn rolling_sum(s: &Series, options: RollingOptions) -> PolarsResult<Series> {
s.rolling_sum(options.clone().into())
s.rolling_sum(options.clone().try_into()?)
}

pub(super) fn rolling_sum_by(s: &[Series], options: RollingOptions) -> PolarsResult<Series> {
convert(|options| s[0].rolling_sum(options), s, "rolling_sum")(options)
}

pub(super) fn rolling_quantile(s: &Series, options: RollingOptions) -> PolarsResult<Series> {
s.rolling_quantile(options.clone().into())
s.rolling_quantile(options.clone().try_into()?)
}

pub(super) fn rolling_quantile_by(s: &[Series], options: RollingOptions) -> PolarsResult<Series> {
Expand All @@ -174,15 +174,15 @@ pub(super) fn rolling_quantile_by(s: &[Series], options: RollingOptions) -> Pola
}

pub(super) fn rolling_var(s: &Series, options: RollingOptions) -> PolarsResult<Series> {
s.rolling_var(options.clone().into())
s.rolling_var(options.clone().try_into()?)
}

pub(super) fn rolling_var_by(s: &[Series], options: RollingOptions) -> PolarsResult<Series> {
convert(|options| s[0].rolling_var(options), s, "rolling_var")(options)
}

pub(super) fn rolling_std(s: &Series, options: RollingOptions) -> PolarsResult<Series> {
s.rolling_std(options.clone().into())
s.rolling_std(options.clone().try_into()?)
}

pub(super) fn rolling_std_by(s: &[Series], options: RollingOptions) -> PolarsResult<Series> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,7 @@ where
let arr = ca.downcast_iter().next().unwrap();
// "5i" is a window size of 5, e.g. fixed
let arr = if options.window_size.parsed_int {
let options: RollingOptionsFixedWindow = options.into();
check_input(options.window_size, options.min_periods)?;
let options: RollingOptionsFixedWindow = options.try_into()?;

Ok(match ca.null_count() {
0 => rolling_agg_fn(
Expand Down
51 changes: 24 additions & 27 deletions crates/polars-time/src/chunkedarray/rolling_window/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,15 +80,22 @@ pub struct RollingOptionsImpl<'a> {
pub fn_params: DynArgs,
}

impl From<RollingOptions> for RollingOptionsImpl<'static> {
fn from(options: RollingOptions) -> Self {
impl TryFrom<RollingOptions> for RollingOptionsImpl<'static> {
type Error = PolarsError;

fn try_from(options: RollingOptions) -> PolarsResult<Self> {
let window_size = options.window_size;
assert!(
window_size.parsed_int,
"should be fixed integer window size at this point"
);
polars_ensure!(
options.closed_window.is_none(),
InvalidOperation: "`closed_window` is not supported for fixed window size rolling aggregations, \
consider using DataFrame.rolling for greater flexibility",
);

RollingOptionsImpl {
Ok(RollingOptionsImpl {
window_size,
min_periods: options.min_periods,
weights: options.weights,
Expand All @@ -98,25 +105,7 @@ impl From<RollingOptions> for RollingOptionsImpl<'static> {
tz: None,
closed_window: None,
fn_params: options.fn_params,
}
}
}

impl From<RollingOptions> for RollingOptionsFixedWindow {
fn from(options: RollingOptions) -> Self {
let window_size = options.window_size;
assert!(
window_size.parsed_int,
"should be fixed integer window size at this point"
);

RollingOptionsFixedWindow {
window_size: window_size.nanoseconds() as usize,
min_periods: options.min_periods,
weights: options.weights,
center: options.center,
fn_params: options.fn_params,
}
})
}
}

Expand All @@ -136,21 +125,29 @@ impl Default for RollingOptionsImpl<'static> {
}
}

impl<'a> From<RollingOptionsImpl<'a>> for RollingOptionsFixedWindow {
fn from(options: RollingOptionsImpl<'a>) -> Self {
impl<'a> TryFrom<RollingOptionsImpl<'a>> for RollingOptionsFixedWindow {
type Error = PolarsError;
fn try_from(options: RollingOptionsImpl<'a>) -> PolarsResult<Self> {
let window_size = options.window_size;
assert!(
window_size.parsed_int,
"should be fixed integer window size at this point"
);
polars_ensure!(
options.closed_window.is_none(),
InvalidOperation: "`closed_window` is not supported for fixed window size rolling aggregations, \
consider using DataFrame.rolling for greater flexibility",
);
let window_size = window_size.nanoseconds() as usize;
check_input(window_size, options.min_periods)?;

RollingOptionsFixedWindow {
window_size: window_size.nanoseconds() as usize,
Ok(RollingOptionsFixedWindow {
window_size,
min_periods: options.min_periods,
weights: options.weights,
center: options.center,
fn_params: options.fn_params,
}
})
}
}

Expand Down
32 changes: 16 additions & 16 deletions py-polars/polars/expr/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -5717,7 +5717,7 @@ def rolling_min(
*,
center: bool = False,
by: str | None = None,
closed: ClosedInterval = "right",
closed: ClosedInterval | None = None,
warn_if_unsorted: bool = True,
) -> Self:
"""
Expand Down Expand Up @@ -5789,7 +5789,7 @@ def rolling_min(
results will not be correct.
closed : {'left', 'right', 'both', 'none'}
Define which sides of the temporal interval are closed (inclusive); only
applicable if `by` has been set.
applicable if `by` has been set (in which case, it defaults to `'right'`).
warn_if_unsorted
Warn if data is not known to be sorted by `by` column (if passed).
Expand Down Expand Up @@ -5929,7 +5929,7 @@ def rolling_max(
*,
center: bool = False,
by: str | None = None,
closed: ClosedInterval = "right",
closed: ClosedInterval | None = None,
warn_if_unsorted: bool = True,
) -> Self:
"""
Expand Down Expand Up @@ -5997,7 +5997,7 @@ def rolling_max(
be of dtype Datetime or Date.
closed : {'left', 'right', 'both', 'none'}
Define which sides of the temporal interval are closed (inclusive); only
applicable if `by` has been set.
applicable if `by` has been set (in which case, it defaults to `'right'`).
warn_if_unsorted
Warn if data is not known to be sorted by `by` column (if passed).
Expand Down Expand Up @@ -6166,7 +6166,7 @@ def rolling_mean(
*,
center: bool = False,
by: str | None = None,
closed: ClosedInterval = "right",
closed: ClosedInterval | None = None,
warn_if_unsorted: bool = True,
) -> Self:
"""
Expand Down Expand Up @@ -6238,7 +6238,7 @@ def rolling_mean(
results will not be correct.
closed : {'left', 'right', 'both', 'none'}
Define which sides of the temporal interval are closed (inclusive); only
applicable if `by` has been set.
applicable if `by` has been set (in which case, it defaults to `'right'`).
warn_if_unsorted
Warn if data is not known to be sorted by `by` column (if passed).
Expand Down Expand Up @@ -6413,7 +6413,7 @@ def rolling_sum(
*,
center: bool = False,
by: str | None = None,
closed: ClosedInterval = "right",
closed: ClosedInterval | None = None,
warn_if_unsorted: bool = True,
) -> Self:
"""
Expand Down Expand Up @@ -6481,7 +6481,7 @@ def rolling_sum(
of dtype `{Date, Datetime}`
closed : {'left', 'right', 'both', 'none'}
Define which sides of the temporal interval are closed (inclusive); only
applicable if `by` has been set.
applicable if `by` has been set (in which case, it defaults to `'right'`).
warn_if_unsorted
Warn if data is not known to be sorted by `by` column (if passed).
Expand Down Expand Up @@ -6650,7 +6650,7 @@ def rolling_std(
*,
center: bool = False,
by: str | None = None,
closed: ClosedInterval = "right",
closed: ClosedInterval | None = None,
ddof: int = 1,
warn_if_unsorted: bool = True,
) -> Self:
Expand Down Expand Up @@ -6720,7 +6720,7 @@ def rolling_std(
results will not be correct.
closed : {'left', 'right', 'both', 'none'}
Define which sides of the temporal interval are closed (inclusive); only
applicable if `by` has been set.
applicable if `by` has been set (in which case, it defaults to `'right'`).
ddof
"Delta Degrees of Freedom": The divisor for a length N window is N - ddof
warn_if_unsorted
Expand Down Expand Up @@ -6898,7 +6898,7 @@ def rolling_var(
*,
center: bool = False,
by: str | None = None,
closed: ClosedInterval = "right",
closed: ClosedInterval | None = None,
ddof: int = 1,
warn_if_unsorted: bool = True,
) -> Self:
Expand Down Expand Up @@ -6967,7 +6967,7 @@ def rolling_var(
results will not be correct.
closed : {'left', 'right', 'both', 'none'}
Define which sides of the temporal interval are closed (inclusive); only
applicable if `by` has been set.
applicable if `by` has been set (in which case, it defaults to `'right'`).
ddof
"Delta Degrees of Freedom": The divisor for a length N window is N - ddof
warn_if_unsorted
Expand Down Expand Up @@ -7145,7 +7145,7 @@ def rolling_median(
*,
center: bool = False,
by: str | None = None,
closed: ClosedInterval = "right",
closed: ClosedInterval | None = None,
warn_if_unsorted: bool = True,
) -> Self:
"""
Expand Down Expand Up @@ -7214,7 +7214,7 @@ def rolling_median(
results will not be correct.
closed : {'left', 'right', 'both', 'none'}
Define which sides of the temporal interval are closed (inclusive); only
applicable if `by` has been set.
applicable if `by` has been set (in which case, it defaults to `'right'`).
warn_if_unsorted
Warn if data is not known to be sorted by `by` column (if passed).
Expand Down Expand Up @@ -7305,7 +7305,7 @@ def rolling_quantile(
*,
center: bool = False,
by: str | None = None,
closed: ClosedInterval = "right",
closed: ClosedInterval | None = None,
warn_if_unsorted: bool = True,
) -> Self:
"""
Expand Down Expand Up @@ -7377,7 +7377,7 @@ def rolling_quantile(
results will not be correct.
closed : {'left', 'right', 'both', 'none'}
Define which sides of the temporal interval are closed (inclusive); only
applicable if `by` has been set.
applicable if `by` has been set (in which case, it defaults to `'right'`).
warn_if_unsorted
Warn if data is not known to be sorted by `by` column (if passed).
Expand Down
10 changes: 10 additions & 0 deletions py-polars/tests/unit/operations/rolling/test_rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,16 @@ def test_rolling_infinity() -> None:
assert_series_equal(s, expected)


def test_rolling_invalid_closed_option() -> None:
df = pl.DataFrame(
{"a": [4, 5, 6], "b": [date(2020, 1, 1), date(2020, 1, 2), date(2020, 1, 3)]}
).sort("a", "b")
with pytest.raises(InvalidOperationError, match="consider using DataFrame.rolling"):
df.with_columns(pl.col("a").rolling_sum(2, closed="left"))
with pytest.raises(InvalidOperationError, match="consider using DataFrame.rolling"):
df.with_columns(pl.col("a").rolling_sum(2, by="b", closed="left"))


def test_rolling_extrema() -> None:
# sorted data and nulls flags trigger different kernels
df = (
Expand Down

0 comments on commit e21bd27

Please sign in to comment.