Skip to content

Commit

Permalink
Added an optimization with how WOE is calculated
Browse files Browse the repository at this point in the history
  • Loading branch information
jinlow committed May 5, 2022
1 parent cc0a42a commit 0bf4bf2
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 18 deletions.
24 changes: 11 additions & 13 deletions crates/discrust_core/src/feature.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::utils::{first_greater_than, nan_safe_compare};
use crate::utils::nan_safe_compare;
use crate::DiscrustError;
use std::{cmp::Ordering, collections::HashMap};

Expand Down Expand Up @@ -122,7 +122,7 @@ impl Feature {
// Maybe if we need to speed things up again, we can consider skipping this part
// and only use the tuples? Here we just set it to the iterator, so it's
// not actually consumed.
let mut sort_index = sort_tuples.iter().map(|(i, _)| *i); // .collect();
let mut sort_index = sort_tuples.iter().map(|(i, _)| *i);

// Now loop over all columns, collecting aggregate statistics.
let mut zero_ct: Vec<f64> = Vec::new();
Expand Down Expand Up @@ -258,7 +258,7 @@ impl Feature {
/// above the split.
pub fn split_iv_woe(
&self,
split_value: f64,
split_idx: usize,
start: usize,
stop: usize,
) -> ((f64, f64), (f64, f64)) {
Expand All @@ -268,7 +268,7 @@ impl Feature {
// This means the split_idx, will be one after our actual
// split value, thus the left hand side will include
// the split value.
let split_idx = first_greater_than(&self.vals_[start..stop], &split_value) + start;
let split_idx = split_idx + 1 + start;

// Accumulate the left hand side.
let lhs_zero_dist = sum_of_cuml_subarray(&self.cuml_zero_dist_, start, split_idx - 1);
Expand All @@ -287,11 +287,11 @@ impl Feature {

pub fn split_totals_ct_ones_ct(
&self,
split_value: f64,
split_idx: usize,
start: usize,
stop: usize,
) -> ((f64, f64), (f64, f64)) {
let split_idx = first_greater_than(&self.vals_[start..stop], &split_value) + start;
let split_idx = split_idx + 1 + start;

let lhs_ct = sum_of_cuml_subarray(&self.cuml_totals_ct_, start, split_idx - 1);
let lhs_ones = sum_of_cuml_subarray(&self.cuml_ones_ct_, start, split_idx - 1);
Expand Down Expand Up @@ -361,7 +361,8 @@ mod test {
let w_ = vec![1.0; x_.len()];
let f = Feature::new(&x_, &y_, &w_, &Vec::new()).unwrap();
assert_eq!(
f.split_iv_woe(5.0, 0, f.vals_.len()),
// 0, 4, 5, (Split on 5.0)
f.split_iv_woe(2, 0, f.vals_.len()),
(
// (0.022314355131420965, -0.2231435513142097),
// (0.018232155679395495, 0.1823215567939548)
Expand All @@ -372,7 +373,7 @@ mod test {

// The same test but on a subset of the data
assert_eq!(
f.split_iv_woe(5.0, 1, 5),
f.split_iv_woe(1, 1, 5),
(
// (0.011157177565710483, -0.2231435513142097),
// (0.011157177565710483, -0.2231435513142097)
Expand All @@ -388,15 +389,12 @@ mod test {
let w_ = vec![1.0; x_.len()];
let f = Feature::new(&x_, &y_, &w_, &Vec::new()).unwrap();
assert_eq!(
f.split_totals_ct_ones_ct(5.0, 0, f.vals_.len()),
f.split_totals_ct_ones_ct(2, 0, f.vals_.len()),
((4.0, 2.0), (5.0, 3.0))
);

// The same test but on a subset of the data
assert_eq!(
f.split_totals_ct_ones_ct(5.0, 1, 5),
((2.0, 1.0), (2.0, 1.0))
)
assert_eq!(f.split_totals_ct_ones_ct(1, 1, 5), ((2.0, 1.0), (2.0, 1.0)))
}
#[test]
fn test_accumulate() {
Expand Down
10 changes: 5 additions & 5 deletions crates/discrust_core/src/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ impl Node {

fn eval_values<'a>(&self, feature: &'a Feature) -> &'a [f64] {
// We do not need to evaluate the last value, as this is not a
// valid value becase there are no records greater than it.
// valid value because there are no records greater than it.
feature.vals_[self.start..(self.stop - 1)].as_ref()
}

Expand All @@ -104,9 +104,9 @@ impl Node {
let mut best_rhs_woe = 0.0;
let mut best_split = -f64::INFINITY;

for v in self.eval_values(feature) {
for (i, v) in self.eval_values(feature).iter().enumerate() {
let ((lhs_ct, lhs_ones), (rhs_ct, rhs_ones)) =
feature.split_totals_ct_ones_ct(*v, self.start, self.stop);
feature.split_totals_ct_ones_ct(i, self.start, self.stop);
// Min response
if (lhs_ones < self.min_pos) | (rhs_ones < self.min_pos) {
continue;
Expand All @@ -119,7 +119,7 @@ impl Node {

// Get information value for split.
let ((lhs_iv, lhs_woe), (rhs_iv, rhs_woe)) =
feature.split_iv_woe(*v, self.start, self.stop);
feature.split_iv_woe(i, self.start, self.stop);

let total_iv = lhs_iv + rhs_iv;
if total_iv < self.min_iv {
Expand All @@ -140,7 +140,7 @@ impl Node {
continue;
}
} else if split_sign == -1 {
continue;
continue;
}
}
// Collect best
Expand Down

0 comments on commit 0bf4bf2

Please sign in to comment.