From f21e5c77b68561869fafdd7029cca2822f5db3b6 Mon Sep 17 00:00:00 2001 From: Wainberg Date: Mon, 15 Jan 2024 03:57:59 -0500 Subject: [PATCH] fix(python): support corr() for single-column DataFrames (#13728) Co-authored-by: Wainberg --- py-polars/polars/dataframe/frame.py | 5 ++++- py-polars/tests/unit/operations/test_statistics.py | 5 +++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py index 4c5a5eb5ca82..b9863f13e943 100644 --- a/py-polars/polars/dataframe/frame.py +++ b/py-polars/polars/dataframe/frame.py @@ -10220,7 +10220,10 @@ def corr(self, **kwargs: Any) -> DataFrame: │ 1.0 ┆ -1.0 ┆ 1.0 │ └──────┴──────┴──────┘ """ - return DataFrame(np.corrcoef(self.to_numpy().T, **kwargs), schema=self.columns) + correlation_matrix = np.corrcoef(self.to_numpy(), rowvar=False, **kwargs) + if self.width == 1: + correlation_matrix = np.array([correlation_matrix]) + return DataFrame(correlation_matrix, schema=self.columns) def merge_sorted(self, other: DataFrame, key: str) -> DataFrame: """ diff --git a/py-polars/tests/unit/operations/test_statistics.py b/py-polars/tests/unit/operations/test_statistics.py index 73998c535402..91d5b388f08f 100644 --- a/py-polars/tests/unit/operations/test_statistics.py +++ b/py-polars/tests/unit/operations/test_statistics.py @@ -9,6 +9,11 @@ def test_corr() -> None: + df = pl.DataFrame({"a": [1, 2, 3]}) + result = df.corr() + expected = pl.DataFrame({"a": [1.0]}) + assert_frame_equal(result, expected) + df = pl.DataFrame( { "a": [1, 2, 4],