From 83b5042e0f01ce3f22fa3598160cd82b61d66ad4 Mon Sep 17 00:00:00 2001 From: Marshall Date: Thu, 28 Nov 2024 02:35:17 -0500 Subject: [PATCH] fix: Improve binning in `Series.hist` with `bin_count` when all values are the same (#20034) --- crates/polars-ops/src/chunked_array/hist.rs | 10 ++++++++-- py-polars/tests/unit/operations/test_hist.py | 14 ++++++++++++++ 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/crates/polars-ops/src/chunked_array/hist.rs b/crates/polars-ops/src/chunked_array/hist.rs index 27ab0df27798..ca0b125cb0f3 100644 --- a/crates/polars-ops/src/chunked_array/hist.rs +++ b/crates/polars-ops/src/chunked_array/hist.rs @@ -69,8 +69,14 @@ where // Determine outer bin edges from the data itself let min_value = ca.min().unwrap().to_f64().unwrap(); let max_value = ca.max().unwrap().to_f64().unwrap(); - pad_lower = true; - (min_value, (max_value - min_value) / bin_count as f64) + + // All data points are identical--use unit interval. + if min_value == max_value { + (min_value - 0.5, 1.0 / bin_count as f64) + } else { + pad_lower = true; + (min_value, (max_value - min_value) / bin_count as f64) + } }; let out = (0..bin_count + 1) .map(|x| (x as f64 * width) + offset) diff --git a/py-polars/tests/unit/operations/test_hist.py b/py-polars/tests/unit/operations/test_hist.py index 4c9b2e0f94e8..bb94f7b94d9b 100644 --- a/py-polars/tests/unit/operations/test_hist.py +++ b/py-polars/tests/unit/operations/test_hist.py @@ -370,6 +370,8 @@ def test_hist_all_null() -> None: def test_hist_rand(n_values: int, n_null: int) -> None: s_rand = pl.Series([None] * n_null, dtype=pl.Int64) s_values = pl.Series(np.random.randint(0, 100, n_values), dtype=pl.Int64) + if s_values.n_unique() == 1: + pytest.skip("Identical values not tested.") s = pl.concat((s_rand, s_values)) out = s.hist(bin_count=10) @@ -424,3 +426,15 @@ def test_hist_max_boundary_19998() -> None: ) result = s.hist(bin_count=50) assert result["count"].sum() == 4 + + +def test_hist_same_values_20030() -> None: + out = pl.Series([1, 1]).hist(bin_count=2) + expected = pl.DataFrame( + { + "breakpoint": pl.Series([1.0, 1.5], dtype=pl.Float64), + "category": pl.Series(["(0.5, 1.0]", "(1.0, 1.5]"], dtype=pl.Categorical), + "count": pl.Series([2, 0], dtype=pl.get_index_type()), + } + ) + assert_frame_equal(out, expected)