Skip to content

Commit

Permalink
fix: Improve histogram bin logic (pola-rs#18761)
Browse files Browse the repository at this point in the history
  • Loading branch information
mcrumiller authored Nov 19, 2024
1 parent 6c34d59 commit 282ed31
Show file tree
Hide file tree
Showing 4 changed files with 615 additions and 161 deletions.
295 changes: 193 additions & 102 deletions crates/polars-ops/src/chunked_array/hist.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,143 +3,234 @@ use std::fmt::Write;
use num_traits::ToPrimitive;
use polars_core::prelude::*;
use polars_core::with_match_physical_numeric_polars_type;
use polars_utils::total_ord::ToTotalOrd;

fn compute_hist<T>(
const DEFAULT_BIN_COUNT: usize = 10;

fn get_breaks<T>(
ca: &ChunkedArray<T>,
bin_count: Option<usize>,
bins: Option<&[f64]>,
include_category: bool,
include_breakpoint: bool,
) -> Series
) -> PolarsResult<(Vec<f64>, bool, bool)>
where
T: PolarsNumericType,
ChunkedArray<T>: ChunkAgg<T::Native>,
{
let mut lower_bound: f64;
let (breaks, count) = if let Some(bins) = bins {
let mut breaks = Vec::with_capacity(bins.len() + 1);
breaks.extend_from_slice(bins);
breaks.sort_unstable_by_key(|k| k.to_total_ord());
breaks.push(f64::INFINITY);

let sorted = ca.sort(false);

let mut count: Vec<IdxSize> = Vec::with_capacity(breaks.len());
let mut current_count: IdxSize = 0;
let mut breaks_iter = breaks.iter();

// We start with the lower garbage bin.
// (-inf, B0]
lower_bound = f64::NEG_INFINITY;
let mut upper_bound = *breaks_iter.next().unwrap();

for chunk in sorted.downcast_iter() {
for item in chunk.non_null_values_iter() {
let item = item.to_f64().unwrap();

// Not a member of current interval
if !(item <= upper_bound && item > lower_bound) {
loop {
// So we push the previous interval
count.push(current_count);
current_count = 0;
lower_bound = upper_bound;
upper_bound = *breaks_iter.next().unwrap();
if item <= upper_bound && item > lower_bound {
break;
}
let mut pad_lower = false;
let (bins, uniform) = match (bin_count, bins) {
(Some(_), Some(_)) => {
return Err(PolarsError::ComputeError(
"can only provide one of `bin_count` or `bins`".into(),
));
},
(None, Some(bins)) => {
// User-supplied bins. Note these are actually bin edges. Check for monotonicity.
// If we only have one edge, we have no bins.
let bin_len = bins.len();
// We also check for uniformity of bins. We declare uniformity if the difference
// between the largest and smallest bin is < 0.00001 the average bin size.
if bin_len > 1 {
let mut smallest = bins[1] - bins[0];
let mut largest = smallest;
let mut avg_bin_size = smallest;
for i in 1..bins.len() {
let d = bins[i] - bins[i - 1];
if d <= 0.0 {
return Err(PolarsError::ComputeError(
"bins must increase monotonically".into(),
));
}
if d > largest {
largest = d;
} else if d < smallest {
smallest = d;
}
avg_bin_size += d;
}
current_count += 1;
let uniform = (largest - smallest) / (avg_bin_size / bin_len as f64) < 0.00001;
(bins.to_vec(), uniform)
} else {
(Vec::<f64>::new(), false) // uniformity doesn't matter here
}
},
(bin_count, None) => {
// User-supplied bin count, or 10 by default. Compute edges from the data.
let bin_count = bin_count.unwrap_or(DEFAULT_BIN_COUNT);
let n = ca.len() - ca.null_count();
let (offset, width) = if n == 0 {
// No non-null items; supply unit interval.
(0.0, 1.0 / bin_count as f64)
} else if n == 1 {
// Unit interval around single point
let idx = ca.first_non_null().unwrap();
// SAFETY: idx is guaranteed to contain an element.
let center = unsafe { ca.get_unchecked(idx) }.unwrap().to_f64().unwrap();
(center - 0.5, 1.0 / bin_count as f64)
} else {
// Determine outer bin edges from the data itself
let min_value = ca.min().unwrap().to_f64().unwrap();
let max_value = ca.max().unwrap().to_f64().unwrap();
pad_lower = true;
(min_value, (max_value - min_value) / bin_count as f64)
};
let out = (0..bin_count + 1)
.map(|x| (x as f64 * width) + offset)
.collect::<Vec<f64>>();
(out, true)
},
};
Ok((bins, uniform, pad_lower))
}

// O(n) implementation when buckets are fixed-size.
// We deposit items directly into their buckets.
//
// `breaks` must contain at least two edges of (approximately) equal spacing.
// Bins are half-open `(lower, upper]`; when `include_lower` is set, items equal
// to the lowest edge are counted in the first bin as well. Items outside
// `[min_break, max_break]` are ignored. Nulls are skipped.
fn uniform_hist_count<T>(breaks: &[f64], ca: &ChunkedArray<T>, include_lower: bool) -> Vec<IdxSize>
where
    T: PolarsNumericType,
    ChunkedArray<T>: ChunkAgg<T::Native>,
{
    let num_bins = breaks.len() - 1;
    let mut count: Vec<IdxSize> = vec![0; num_bins];
    let min_break: f64 = breaks[0];
    let max_break: f64 = breaks[num_bins];
    let width = breaks[1] - min_break; // guaranteed at least one bin
    let is_integer = !T::get_dtype().is_float();

    for chunk in ca.downcast_iter() {
        for item in chunk.non_null_values_iter() {
            let item = item.to_f64().unwrap();
            if include_lower && item == min_break {
                count[0] += 1;
            } else if item > min_break && item <= max_break {
                // Direct deposit: the bin index is the (1-based) multiple of `width`
                // above the lowest edge, rounded up, minus one.
                let idx = (item - min_break) / width;
                // This is needed for numeric stability for integers.
                // We can fall directly on a boundary with an integer.
                let idx = if is_integer && (idx.round() - idx).abs() < 0.0000001 {
                    idx.round() - 1.0
                } else {
                    idx.ceil() - 1.0
                };
                count[idx as usize] += 1;
            }
        }
    }
    count
}

// If bin_count is omitted, default to the difference between start and stop (unit bins)
let bin_count = if let Some(bin_count) = bin_count {
bin_count
} else {
(end - start).round() as usize
};
// Variable-width bucketing. We sort the items and then move linearly through buckets.
fn hist_count<T>(breaks: &[f64], ca: &ChunkedArray<T>, include_lower: bool) -> Vec<IdxSize>
where
T: PolarsNumericType,
ChunkedArray<T>: ChunkAgg<T::Native>,
{
let exclude_lower = !include_lower;
let num_bins = breaks.len() - 1;
let mut breaks_iter = breaks.iter().skip(1); // Skip the first lower bound
let (min_break, max_break) = (breaks[0], breaks[breaks.len() - 1]);
let mut upper_bound = *breaks_iter.next().unwrap();
let sorted = ca.sort(false).rechunk();
let mut current_count: IdxSize = 0;
let chunk = sorted.downcast_iter().next().unwrap();
let mut count: Vec<IdxSize> = Vec::with_capacity(num_bins);

// Calculate the breakpoints and make the array. The breakpoints form the RHS of the bins.
let interval = (end - start) / (bin_count as f64);
let breaks_iter = (1..(bin_count)).map(|b| start + (b as f64) * interval);
let mut breaks = Vec::with_capacity(breaks_iter.size_hint().0 + 1);
breaks.extend(breaks_iter);

// Extend the left-most edge by 0.1% of the total range to include the minimum value.
let margin = (end - start) * 0.001;
lower_bound = start - margin;
breaks.push(end);

let mut count: Vec<IdxSize> = vec![0; bin_count];
let max_bin = breaks.len() - 1;
for chunk in ca.downcast_iter() {
for item in chunk.non_null_values_iter() {
let item = item.to_f64().unwrap();
let bin = ((((item - start) / interval).ceil() - 1.0) as usize).min(max_bin);
count[bin] += 1;
'item: for item in chunk.non_null_values_iter() {
let item = item.to_f64().unwrap();

// Cycle through items until we hit the first bucket.
if item < min_break || (exclude_lower && item == min_break) {
continue;
}

while item > upper_bound {
if item > max_break {
// No more items will fit in any buckets
break 'item;
}

// Finished with prior bucket; push, reset, and move to next.
count.push(current_count);
current_count = 0;
upper_bound = *breaks_iter.next().unwrap();
}
(breaks, count)

// Item is in bound.
current_count += 1;
}
count.push(current_count);
count.resize(num_bins, 0); // If we left early, fill remainder with 0.
count
}

fn compute_hist<T>(
ca: &ChunkedArray<T>,
bin_count: Option<usize>,
bins: Option<&[f64]>,
include_category: bool,
include_breakpoint: bool,
) -> PolarsResult<Series>
where
T: PolarsNumericType,
ChunkedArray<T>: ChunkAgg<T::Native>,
{
let (breaks, uniform, pad_lower) = get_breaks(ca, bin_count, bins)?;
let num_bins = std::cmp::max(breaks.len(), 1) - 1;
let count = if num_bins > 0 && ca.len() > ca.null_count() {
if uniform {
uniform_hist_count(&breaks, ca, pad_lower)
} else {
hist_count(&breaks, ca, pad_lower)
}
} else {
vec![0; num_bins]
};

// Generate output: breakpoint (optional), breaks (optional), count
let mut fields = Vec::with_capacity(3);

if include_breakpoint {
let breakpoints = if num_bins > 0 {
Series::new(PlSmallStr::from_static("breakpoint"), &breaks[1..])
} else {
let empty: &[f64; 0] = &[];
Series::new(PlSmallStr::from_static("breakpoint"), empty)
};
fields.push(breakpoints)
}

if include_category {
// Use AnyValue for formatting.
let mut lower = AnyValue::Float64(lower_bound);
let mut categories =
StringChunkedBuilder::new(PlSmallStr::from_static("category"), breaks.len());

let mut buf = String::new();
for br in &breaks {
let br = AnyValue::Float64(*br);
buf.clear();
write!(buf, "({lower}, {br}]").unwrap();
categories.append_value(buf.as_str());
lower = br;
if num_bins > 0 {
let mut lower = AnyValue::Float64(if pad_lower {
breaks[0] - (breaks[num_bins] - breaks[0]) * 0.001
} else {
breaks[0]
});
let mut buf = String::new();
for br in &breaks[1..] {
let br = AnyValue::Float64(*br);
buf.clear();
write!(buf, "({lower}, {br}]").unwrap();
categories.append_value(buf.as_str());
lower = br;
}
}
let categories = categories
.finish()
.cast(&DataType::Categorical(None, Default::default()))
.unwrap();
fields.push(categories);
};
if include_breakpoint {
fields.insert(
0,
Series::new(PlSmallStr::from_static("breakpoint"), breaks),
)
}

let count = Series::new(PlSmallStr::from_static("count"), count);
fields.push(count);

if fields.len() == 1 {
let out = fields.pop().unwrap();
out.with_name(ca.name().clone())
Ok(if fields.len() == 1 {
fields.pop().unwrap().with_name(ca.name().clone())
} else {
StructChunked::from_series(ca.name().clone(), fields[0].len(), fields.iter())
.unwrap()
.into_series()
}
})
}

pub fn hist_series(
Expand All @@ -165,7 +256,7 @@ pub fn hist_series(

let out = with_match_physical_numeric_polars_type!(s.dtype(), |$T| {
let ca: &ChunkedArray<$T> = s.as_ref().as_ref().as_ref();
compute_hist(ca, bin_count, bins_arg, include_category, include_breakpoint)
compute_hist(ca, bin_count, bins_arg, include_category, include_breakpoint)?
});
Ok(out)
}
24 changes: 10 additions & 14 deletions py-polars/polars/expr/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -10078,33 +10078,29 @@ def hist(
--------
>>> df = pl.DataFrame({"a": [1, 3, 8, 8, 2, 1, 3]})
>>> df.select(pl.col("a").hist(bins=[1, 2, 3]))
shape: (2, 1)
┌─────┐
│ a │
│ --- │
│ u32 │
╞═════╡
│ 1 │
│ 2 │
└─────┘
>>> df.select(
... pl.col("a").hist(
... bins=[1, 2, 3], include_breakpoint=True, include_category=True
... )
... )
shape: (2, 1)
┌──────────────────────┐
│ a │
│ --- │
│ struct[3] │
╞══════════════════════╡
│ {2.0,"(1.0, 2.0]",1} │
│ {3.0,"(2.0, 3.0]",2} │
└──────────────────────┘
"""
if bins is not None:
if isinstance(bins, list):
Expand Down
Loading

0 comments on commit 282ed31

Please sign in to comment.