Skip to content

Commit

Permalink
Cast bool to f64 during mean_horizontal
Browse files Browse the repository at this point in the history
  • Loading branch information
mcrumiller committed Mar 18, 2024
1 parent 9f7ec49 commit 82cb292
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 2 deletions.
7 changes: 6 additions & 1 deletion crates/polars-core/src/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2561,7 +2561,12 @@ impl DataFrame {
pub fn mean_horizontal(&self, null_strategy: NullStrategy) -> PolarsResult<Option<Series>> {
match self.columns.len() {
0 => Ok(None),
1 => Ok(Some(self.columns[0].clone())),
1 => Ok(Some(match self.columns[0].dtype() {
dt if dt != &DataType::Float32 && (dt.is_numeric() || dt == &DataType::Boolean) => {
self.columns[0].cast(&DataType::Float64)?
},
_ => self.columns[0].clone(),
})),
_ => {
let columns = self
.columns
Expand Down
59 changes: 59 additions & 0 deletions py-polars/tests/unit/operations/test_aggregations.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
if TYPE_CHECKING:
import numpy.typing as npt

from polars.type_aliases import PolarsDataType


def test_quantile_expr_input() -> None:
df = pl.DataFrame({"a": [1, 2, 3, 4, 5], "b": [0, 0, 0.3, 0.2, 0]})
Expand Down Expand Up @@ -471,3 +473,60 @@ def test_grouping_hash_14749() -> None:
.select(pl.col("x").max().over("grp"))["x"]
.value_counts()
).to_dict(as_series=False) == {"x": [3], "count": [1004]}


@pytest.mark.parametrize(
("in_dtype", "out_dtype"),
[
(pl.Boolean, pl.Float64),
(pl.UInt8, pl.Float64),
(pl.UInt16, pl.Float64),
(pl.UInt32, pl.Float64),
(pl.UInt64, pl.Float64),
(pl.Int8, pl.Float64),
(pl.Int16, pl.Float64),
(pl.Int32, pl.Float64),
(pl.Int64, pl.Float64),
(pl.Float32, pl.Float32),
(pl.Float64, pl.Float64),
],
)
def test_horizontal_mean_single_column(
in_dtype: PolarsDataType,
out_dtype: PolarsDataType,
) -> None:
out = (
pl.LazyFrame({"a": pl.Series([1, 0], dtype=in_dtype)})
.select(pl.mean_horizontal(pl.all()))
.collect()
)

assert_frame_equal(out, pl.DataFrame({"a": pl.Series([1.0, 0.0], dtype=out_dtype)}))


def test_horizontal_mean_in_groupby_15115() -> None:
nbr_records = 1000
out = (
pl.LazyFrame(
{
"w": [None, "one", "two", "three"] * nbr_records,
"x": [None, None, "two", "three"] * nbr_records,
"y": [None, None, None, "three"] * nbr_records,
"z": [None, None, None, None] * nbr_records,
}
)
.select(pl.mean_horizontal(pl.all().is_null()).alias("mean_null"))
.group_by("mean_null")
.len()
.sort(by="mean_null")
.collect()
)
assert_frame_equal(
out,
pl.DataFrame(
{
"mean_null": pl.Series([0.25, 0.5, 0.75, 1.0], dtype=pl.Float64),
"len": pl.Series([nbr_records] * 4, dtype=pl.UInt32),
}
),
)
32 changes: 31 additions & 1 deletion py-polars/tests/unit/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@

from collections import OrderedDict
from datetime import date, timedelta
from typing import Any, Iterator, Mapping
from typing import TYPE_CHECKING, Any, Iterator, Mapping

import pytest

import polars as pl
from polars.testing import assert_frame_equal, assert_series_equal

if TYPE_CHECKING:
from polars.type_aliases import PolarsDataType


class CustomSchema(Mapping[str, Any]):
"""Dummy schema object for testing compatibility with Mapping."""
Expand Down Expand Up @@ -637,3 +640,30 @@ def test_literal_subtract_schema_13284() -> None:
def test_schema_boolean_sum_horizontal() -> None:
lf = pl.LazyFrame({"a": [True, False]}).select(pl.sum_horizontal("a"))
assert lf.schema == OrderedDict([("a", pl.UInt32)])


@pytest.mark.parametrize(
("in_dtype", "out_dtype"),
[
(pl.Boolean, pl.Float64),
(pl.UInt8, pl.Float64),
(pl.UInt16, pl.Float64),
(pl.UInt32, pl.Float64),
(pl.UInt64, pl.Float64),
(pl.Int8, pl.Float64),
(pl.Int16, pl.Float64),
(pl.Int32, pl.Float64),
(pl.Int64, pl.Float64),
(pl.Float32, pl.Float32),
(pl.Float64, pl.Float64),
],
)
def test_schema_mean_horizontal_single_column(
in_dtype: PolarsDataType,
out_dtype: PolarsDataType,
) -> None:
lf = pl.LazyFrame({"a": pl.Series([1, 0], dtype=in_dtype)}).select(
pl.mean_horizontal(pl.all())
)

assert lf.schema == OrderedDict([("a", out_dtype)])

0 comments on commit 82cb292

Please sign in to comment.