diff --git a/crates/polars-plan/src/logical_plan/aexpr/schema.rs b/crates/polars-plan/src/logical_plan/aexpr/schema.rs index 3f7b11613744..45e5f7c5cba4 100644 --- a/crates/polars-plan/src/logical_plan/aexpr/schema.rs +++ b/crates/polars-plan/src/logical_plan/aexpr/schema.rs @@ -1,7 +1,9 @@ use super::*; fn float_type(field: &mut Field) { - if field.dtype.is_numeric() && !matches!(&field.dtype, DataType::Float32) { + if (field.dtype.is_numeric() || field.dtype == DataType::Boolean) + && field.dtype != DataType::Float32 + { field.coerce(DataType::Float64) } } diff --git a/py-polars/tests/unit/streaming/test_streaming_group_by.py b/py-polars/tests/unit/streaming/test_streaming_group_by.py index 35715f18179c..c93ec2357ba5 100644 --- a/py-polars/tests/unit/streaming/test_streaming_group_by.py +++ b/py-polars/tests/unit/streaming/test_streaming_group_by.py @@ -80,7 +80,7 @@ def test_streaming_group_by_types() -> None: "str_sum": pl.String, "bool_first": pl.Boolean, "bool_last": pl.Boolean, - "bool_mean": pl.Boolean, + "bool_mean": pl.Float64, "bool_sum": pl.UInt32, "date_sum": pl.Date, "date_mean": pl.Date,