Skip to content

Commit

Permalink
fix(python,rust): Compute Spearman rank correlations using average ra… (
Browse files Browse the repository at this point in the history
  • Loading branch information
zundertj authored Jun 18, 2023
1 parent 05d1195 commit f8a1ee4
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -179,25 +179,20 @@ fn spearman_rank_corr(s: &[Series], ddof: u8, propagate_nans: bool) -> PolarsRes
let a = a.drop_nulls();
let b = b.drop_nulls();

let a_idx = a.rank(
let a_rank = a.rank(
RankOptions {
method: RankMethod::Min,
method: RankMethod::Average,
..Default::default()
},
None,
);
let b_idx = b.rank(
let b_rank = b.rank(
RankOptions {
method: RankMethod::Min,
method: RankMethod::Average,
..Default::default()
},
None,
);
let a_idx = a_idx.idx().unwrap();
let b_idx = b_idx.idx().unwrap();

Ok(Series::new(
name,
&[polars_core::functions::pearson_corr_i(a_idx, b_idx, ddof)],
))
pearson_corr(&[a_rank, b_rank], ddof)
}
2 changes: 1 addition & 1 deletion py-polars/polars/functions/lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1421,7 +1421,7 @@ def corr(
┌─────┐
│ a │
│ --- │
f64
f32
╞═════╡
│ 0.5 │
└─────┘
Expand Down
19 changes: 19 additions & 0 deletions py-polars/tests/unit/test_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -979,6 +979,25 @@ def test_spearman_corr() -> None:
assert np.isclose(out[1], -1.0)


def test_spearman_corr_ties() -> None:
"""In Spearman correlation, ranks are computed using the average method ."""
df = pl.DataFrame({"a": [1, 1, 1, 2, 3, 7, 4], "b": [4, 3, 2, 2, 4, 3, 1]})

result = df.select(
pl.corr("a", "b", method="spearman").alias("a1"),
pl.corr(pl.col("a").rank("min"), pl.col("b").rank("min")).alias("a2"),
pl.corr(pl.col("a").rank(), pl.col("b").rank()).alias("a3"),
)
expected = pl.DataFrame(
[
pl.Series("a1", [-0.19048483669757843], dtype=pl.Float32),
pl.Series("a2", [-0.17223653586587362], dtype=pl.Float64),
pl.Series("a3", [-0.19048483669757843], dtype=pl.Float32),
]
)
assert_frame_equal(result, expected)


def test_pearson_corr() -> None:
ldf = pl.LazyFrame(
{
Expand Down

0 comments on commit f8a1ee4

Please sign in to comment.