diff --git a/polars/polars-lazy/polars-plan/src/dsl/function_expr/correlation.rs b/polars/polars-lazy/polars-plan/src/dsl/function_expr/correlation.rs index 5e585a14a7b3..1a1b36c310d6 100644 --- a/polars/polars-lazy/polars-plan/src/dsl/function_expr/correlation.rs +++ b/polars/polars-lazy/polars-plan/src/dsl/function_expr/correlation.rs @@ -179,25 +179,20 @@ fn spearman_rank_corr(s: &[Series], ddof: u8, propagate_nans: bool) -> PolarsRes let a = a.drop_nulls(); let b = b.drop_nulls(); - let a_idx = a.rank( + let a_rank = a.rank( RankOptions { - method: RankMethod::Min, + method: RankMethod::Average, ..Default::default() }, None, ); - let b_idx = b.rank( + let b_rank = b.rank( RankOptions { - method: RankMethod::Min, + method: RankMethod::Average, ..Default::default() }, None, ); - let a_idx = a_idx.idx().unwrap(); - let b_idx = b_idx.idx().unwrap(); - Ok(Series::new( - name, - &[polars_core::functions::pearson_corr_i(a_idx, b_idx, ddof)], - )) + pearson_corr(&[a_rank, b_rank], ddof) } diff --git a/py-polars/polars/functions/lazy.py b/py-polars/polars/functions/lazy.py index cc3998c7d035..135587b5fd06 100644 --- a/py-polars/polars/functions/lazy.py +++ b/py-polars/polars/functions/lazy.py @@ -1421,7 +1421,7 @@ def corr( ┌─────┐ │ a │ │ --- │ - │ f64 │ + │ f32 │ ╞═════╡ │ 0.5 │ └─────┘ diff --git a/py-polars/tests/unit/test_lazy.py b/py-polars/tests/unit/test_lazy.py index 0c8b231775f8..6065248671f5 100644 --- a/py-polars/tests/unit/test_lazy.py +++ b/py-polars/tests/unit/test_lazy.py @@ -979,6 +979,25 @@ def test_spearman_corr() -> None: assert np.isclose(out[1], -1.0) +def test_spearman_corr_ties() -> None: + """In Spearman correlation, ranks are computed using the average method .""" + df = pl.DataFrame({"a": [1, 1, 1, 2, 3, 7, 4], "b": [4, 3, 2, 2, 4, 3, 1]}) + + result = df.select( + pl.corr("a", "b", method="spearman").alias("a1"), + pl.corr(pl.col("a").rank("min"), pl.col("b").rank("min")).alias("a2"), + pl.corr(pl.col("a").rank(), pl.col("b").rank()).alias("a3"), + ) + expected = pl.DataFrame( + [ + pl.Series("a1", [-0.19048483669757843], dtype=pl.Float32), + pl.Series("a2", [-0.17223653586587362], dtype=pl.Float64), + pl.Series("a3", [-0.19048483669757843], dtype=pl.Float32), + ] + ) + assert_frame_equal(result, expected) + + def test_pearson_corr() -> None: ldf = pl.LazyFrame( {