diff --git a/crates/polars-core/src/frame/mod.rs b/crates/polars-core/src/frame/mod.rs index 83c8292918f4..d998af0c7b6a 100644 --- a/crates/polars-core/src/frame/mod.rs +++ b/crates/polars-core/src/frame/mod.rs @@ -2562,7 +2562,12 @@ impl DataFrame { pub fn mean_horizontal(&self, null_strategy: NullStrategy) -> PolarsResult> { match self.columns.len() { 0 => Ok(None), - 1 => Ok(Some(self.columns[0].clone())), + 1 => Ok(Some(match self.columns[0].dtype() { + dt if dt != &DataType::Float32 && (dt.is_numeric() || dt == &DataType::Boolean) => { + self.columns[0].cast(&DataType::Float64)? + }, + _ => self.columns[0].clone(), + })), _ => { let columns = self .columns diff --git a/py-polars/tests/unit/operations/test_aggregations.py b/py-polars/tests/unit/operations/test_aggregations.py index c0c02ae874ce..64ea9551dffe 100644 --- a/py-polars/tests/unit/operations/test_aggregations.py +++ b/py-polars/tests/unit/operations/test_aggregations.py @@ -12,6 +12,8 @@ if TYPE_CHECKING: import numpy.typing as npt + from polars.type_aliases import PolarsDataType + def test_quantile_expr_input() -> None: df = pl.DataFrame({"a": [1, 2, 3, 4, 5], "b": [0.0, 0.0, 0.3, 0.2, 0.0]}) @@ -471,3 +473,60 @@ def test_grouping_hash_14749() -> None: .select(pl.col("x").max().over("grp"))["x"] .value_counts() ).to_dict(as_series=False) == {"x": [3], "count": [1004]} + + +@pytest.mark.parametrize( + ("in_dtype", "out_dtype"), + [ + (pl.Boolean, pl.Float64), + (pl.UInt8, pl.Float64), + (pl.UInt16, pl.Float64), + (pl.UInt32, pl.Float64), + (pl.UInt64, pl.Float64), + (pl.Int8, pl.Float64), + (pl.Int16, pl.Float64), + (pl.Int32, pl.Float64), + (pl.Int64, pl.Float64), + (pl.Float32, pl.Float32), + (pl.Float64, pl.Float64), + ], +) +def test_horizontal_mean_single_column( + in_dtype: PolarsDataType, + out_dtype: PolarsDataType, +) -> None: + out = ( + pl.LazyFrame({"a": pl.Series([1, 0], dtype=in_dtype)}) + .select(pl.mean_horizontal(pl.all())) + .collect() + ) + + assert_frame_equal(out, pl.DataFrame({"a": pl.Series([1.0, 0.0], dtype=out_dtype)})) + + +def test_horizontal_mean_in_groupby_15115() -> None: + nbr_records = 1000 + out = ( + pl.LazyFrame( + { + "w": [None, "one", "two", "three"] * nbr_records, + "x": [None, None, "two", "three"] * nbr_records, + "y": [None, None, None, "three"] * nbr_records, + "z": [None, None, None, None] * nbr_records, + } + ) + .select(pl.mean_horizontal(pl.all().is_null()).alias("mean_null")) + .group_by("mean_null") + .len() + .sort(by="mean_null") + .collect() + ) + assert_frame_equal( + out, + pl.DataFrame( + { + "mean_null": pl.Series([0.25, 0.5, 0.75, 1.0], dtype=pl.Float64), + "len": pl.Series([nbr_records] * 4, dtype=pl.UInt32), + } + ), + ) diff --git a/py-polars/tests/unit/test_schema.py b/py-polars/tests/unit/test_schema.py index b0785dad8987..12ff538515e7 100644 --- a/py-polars/tests/unit/test_schema.py +++ b/py-polars/tests/unit/test_schema.py @@ -2,13 +2,16 @@ from collections import OrderedDict from datetime import date, timedelta -from typing import Any, Iterator, Mapping +from typing import TYPE_CHECKING, Any, Iterator, Mapping import pytest import polars as pl from polars.testing import assert_frame_equal, assert_series_equal +if TYPE_CHECKING: + from polars.type_aliases import PolarsDataType + class CustomSchema(Mapping[str, Any]): """Dummy schema object for testing compatibility with Mapping.""" @@ -640,6 +643,33 @@ def test_schema_boolean_sum_horizontal() -> None: assert lf.schema == OrderedDict([("a", pl.UInt32)]) +@pytest.mark.parametrize( + ("in_dtype", "out_dtype"), + [ + (pl.Boolean, pl.Float64), + (pl.UInt8, pl.Float64), + (pl.UInt16, pl.Float64), + (pl.UInt32, pl.Float64), + (pl.UInt64, pl.Float64), + (pl.Int8, pl.Float64), + (pl.Int16, pl.Float64), + (pl.Int32, pl.Float64), + (pl.Int64, pl.Float64), + (pl.Float32, pl.Float32), + (pl.Float64, pl.Float64), + ], +) +def test_schema_mean_horizontal_single_column( + in_dtype: PolarsDataType, + out_dtype: PolarsDataType, +) -> None: + lf = pl.LazyFrame({"a": pl.Series([1, 0], dtype=in_dtype)}).select( + pl.mean_horizontal(pl.all()) + ) + + assert lf.schema == OrderedDict([("a", out_dtype)]) + + def test_struct_alias_prune_15401() -> None: df = pl.DataFrame({"a": []}, schema={"a": pl.Struct({"b": pl.Int8})}) assert df.select(pl.col("a").alias("c").struct.field("b")).columns == ["b"]