From bb4fbe99415c19525584f700454343c14c340b05 Mon Sep 17 00:00:00 2001 From: andrew Date: Sat, 17 Aug 2024 20:05:02 -0400 Subject: [PATCH 1/2] Compute joint null mask before calling rolling corr/cov stats --- .../src/dsl/functions/correlation.rs | 20 +++++++++---- .../unit/operations/rolling/test_rolling.py | 30 +++++++++++++++++++ 2 files changed, 44 insertions(+), 6 deletions(-) diff --git a/crates/polars-plan/src/dsl/functions/correlation.rs b/crates/polars-plan/src/dsl/functions/correlation.rs index bb0fc5aa3cf1..dd7521ad20a9 100644 --- a/crates/polars-plan/src/dsl/functions/correlation.rs +++ b/crates/polars-plan/src/dsl/functions/correlation.rs @@ -79,11 +79,15 @@ pub fn rolling_corr(x: Expr, y: Expr, options: RollingCovOptions) -> Expr { ..Default::default() }; + let non_null_mask = when(x.clone().is_not_null().and(y.clone().is_not_null())) + .then(lit(1.0)) + .otherwise(lit(Null {})); + let mean_x_y = (x.clone() * y.clone()).rolling_mean(rolling_options.clone()); - let mean_x = x.clone().rolling_mean(rolling_options.clone()); - let mean_y = y.clone().rolling_mean(rolling_options.clone()); - let var_x = x.clone().rolling_var(rolling_options.clone()); - let var_y = y.clone().rolling_var(rolling_options); + let mean_x = (x.clone() * non_null_mask.clone()).rolling_mean(rolling_options.clone()); + let mean_y = (y.clone() * non_null_mask.clone()).rolling_mean(rolling_options.clone()); + let var_x = (x.clone() * non_null_mask.clone()).rolling_var(rolling_options.clone()); + let var_y = (y.clone() * non_null_mask.clone()).rolling_var(rolling_options); let rolling_options_count = RollingOptionsFixedWindow { window_size: options.window_size as usize, @@ -110,9 +114,13 @@ pub fn rolling_cov(x: Expr, y: Expr, options: RollingCovOptions) -> Expr { ..Default::default() }; + let non_null_mask = when(x.clone().is_not_null().and(y.clone().is_not_null())) + .then(lit(1.0)) + .otherwise(lit(Null {})); + let mean_x_y = (x.clone() * y.clone()).rolling_mean(rolling_options.clone()); - let mean_x = x.clone().rolling_mean(rolling_options.clone()); - let mean_y = y.clone().rolling_mean(rolling_options); + let mean_x = (x.clone() * non_null_mask.clone()).rolling_mean(rolling_options.clone()); + let mean_y = (y.clone() * non_null_mask.clone()).rolling_mean(rolling_options); let rolling_options_count = RollingOptionsFixedWindow { window_size: options.window_size as usize, min_periods: 0, diff --git a/py-polars/tests/unit/operations/rolling/test_rolling.py b/py-polars/tests/unit/operations/rolling/test_rolling.py index 8e5bbfd69bd1..8139d3786b42 100644 --- a/py-polars/tests/unit/operations/rolling/test_rolling.py +++ b/py-polars/tests/unit/operations/rolling/test_rolling.py @@ -589,6 +589,36 @@ def test_rolling_cov_corr() -> None: assert res["corr"][:2] == [None] * 2 +def test_rolling_cov_corr_nulls() -> None: + df1 = pl.DataFrame( + {"a": [1.06, 1.07, 0.93, 0.78, 0.85], "lag_a": [1.0, 1.06, 1.07, 0.93, 0.78]} + ) + df2 = pl.DataFrame( + { + "a": [1.0, 1.06, 1.07, 0.93, 0.78, 0.85], + "lag_a": [None, 1.0, 1.06, 1.07, 0.93, 0.78], + } + ) + + val_1 = df1.select( + pl.rolling_corr("a", "lag_a", window_size=10, min_periods=5, ddof=1).tail(1) + ).item() + val_2 = df2.select( + pl.rolling_corr("a", "lag_a", window_size=10, min_periods=5, ddof=1).tail(1) + ).item() + + assert val_1 == val_2 + + val_1 = df1.select( + pl.rolling_cov("a", "lag_a", window_size=10, min_periods=5, ddof=1).tail(1) + ).item() + val_2 = df2.select( + pl.rolling_cov("a", "lag_a", window_size=10, min_periods=5, ddof=1).tail(1) + ).item() + + assert val_1 == val_2 + + @pytest.mark.parametrize("time_unit", ["ms", "us", "ns"]) def test_rolling_empty_window_9406(time_unit: TimeUnit) -> None: datecol = pl.Series( From ff1a3180c278ab975d8ebe0dfb78d3c5957b2f6a Mon Sep 17 00:00:00 2001 From: andrew Date: Sun, 18 Aug 2024 09:35:43 -0400 Subject: [PATCH 2/2] explicitly test result frames --- .../unit/operations/rolling/test_rolling.py | 28 ++++++++++++------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/py-polars/tests/unit/operations/rolling/test_rolling.py b/py-polars/tests/unit/operations/rolling/test_rolling.py index 8139d3786b42..d934683d645f 100644 --- a/py-polars/tests/unit/operations/rolling/test_rolling.py +++ b/py-polars/tests/unit/operations/rolling/test_rolling.py @@ -601,22 +601,30 @@ def test_rolling_cov_corr_nulls() -> None: ) val_1 = df1.select( - pl.rolling_corr("a", "lag_a", window_size=10, min_periods=5, ddof=1).tail(1) - ).item() + pl.rolling_corr("a", "lag_a", window_size=10, min_periods=5, ddof=1) + ) val_2 = df2.select( - pl.rolling_corr("a", "lag_a", window_size=10, min_periods=5, ddof=1).tail(1) - ).item() + pl.rolling_corr("a", "lag_a", window_size=10, min_periods=5, ddof=1) + ) + + df1_expected = pl.DataFrame({"a": [None, None, None, None, 0.62204709]}) + df2_expected = pl.DataFrame({"a": [None, None, None, None, None, 0.62204709]}) - assert val_1 == val_2 + assert_frame_equal(val_1, df1_expected, atol=0.0000001) + assert_frame_equal(val_2, df2_expected, atol=0.0000001) val_1 = df1.select( - pl.rolling_cov("a", "lag_a", window_size=10, min_periods=5, ddof=1).tail(1) - ).item() + pl.rolling_cov("a", "lag_a", window_size=10, min_periods=5, ddof=1) + ) val_2 = df2.select( - pl.rolling_cov("a", "lag_a", window_size=10, min_periods=5, ddof=1).tail(1) - ).item() + pl.rolling_cov("a", "lag_a", window_size=10, min_periods=5, ddof=1) + ) + + df1_expected = pl.DataFrame({"a": [None, None, None, None, 0.009445]}) + df2_expected = pl.DataFrame({"a": [None, None, None, None, None, 0.009445]}) - assert val_1 == val_2 + assert_frame_equal(val_1, df1_expected, atol=0.0000001) + assert_frame_equal(val_2, df2_expected, atol=0.0000001) @pytest.mark.parametrize("time_unit", ["ms", "us", "ns"])