From 0bf4bf20a73b656ee9be0f52f57e6031149effa4 Mon Sep 17 00:00:00 2001 From: jinlow Date: Wed, 4 May 2022 22:44:59 -0500 Subject: [PATCH] Added an optimization with how WOE is calculated --- crates/discrust_core/src/feature.rs | 24 +++++++++++------------- crates/discrust_core/src/node.rs | 10 +++++----- 2 files changed, 16 insertions(+), 18 deletions(-) diff --git a/crates/discrust_core/src/feature.rs b/crates/discrust_core/src/feature.rs index 59b452a..79edfe8 100644 --- a/crates/discrust_core/src/feature.rs +++ b/crates/discrust_core/src/feature.rs @@ -1,4 +1,4 @@ -use crate::utils::{first_greater_than, nan_safe_compare}; +use crate::utils::nan_safe_compare; use crate::DiscrustError; use std::{cmp::Ordering, collections::HashMap}; @@ -122,7 +122,7 @@ impl Feature { // Maybe if we need to speed things up again, we can consider skipping this part and // and only use the the tuples? Here we just set it to the iterator, so it's // not actually consumed. - let mut sort_index = sort_tuples.iter().map(|(i, _)| *i); // .collect(); + let mut sort_index = sort_tuples.iter().map(|(i, _)| *i); // Now loop over all columns, collecting aggregate statistics. let mut zero_ct: Vec = Vec::new(); @@ -258,7 +258,7 @@ impl Feature { /// above the split. pub fn split_iv_woe( &self, - split_value: f64, + split_idx: usize, start: usize, stop: usize, ) -> ((f64, f64), (f64, f64)) { @@ -268,7 +268,7 @@ impl Feature { // This means the split_idx, will be one after our actual // split value, thus the left hand side will include // the split value. - let split_idx = first_greater_than(&self.vals_[start..stop], &split_value) + start; + let split_idx = split_idx + 1 + start; // Accumulate the left hand side.
let lhs_zero_dist = sum_of_cuml_subarray(&self.cuml_zero_dist_, start, split_idx - 1); @@ -287,11 +287,11 @@ impl Feature { pub fn split_totals_ct_ones_ct( &self, - split_value: f64, + split_idx: usize, start: usize, stop: usize, ) -> ((f64, f64), (f64, f64)) { - let split_idx = first_greater_than(&self.vals_[start..stop], &split_value) + start; + let split_idx = split_idx + 1 + start; let lhs_ct = sum_of_cuml_subarray(&self.cuml_totals_ct_, start, split_idx - 1); let lhs_ones = sum_of_cuml_subarray(&self.cuml_ones_ct_, start, split_idx - 1); @@ -361,7 +361,8 @@ mod test { let w_ = vec![1.0; x_.len()]; let f = Feature::new(&x_, &y_, &w_, &Vec::new()).unwrap(); assert_eq!( - f.split_iv_woe(5.0, 0, f.vals_.len()), + // 0, 4, 5, (Split on 5.0) + f.split_iv_woe(2, 0, f.vals_.len()), ( // (0.022314355131420965, -0.2231435513142097), // (0.018232155679395495, 0.1823215567939548) @@ -372,7 +373,7 @@ mod test { // The same test but on a subset of the data assert_eq!( - f.split_iv_woe(5.0, 1, 5), + f.split_iv_woe(1, 1, 5), ( // (0.011157177565710483, -0.2231435513142097), // (0.011157177565710483, -0.2231435513142097) @@ -388,15 +389,12 @@ mod test { let w_ = vec![1.0; x_.len()]; let f = Feature::new(&x_, &y_, &w_, &Vec::new()).unwrap(); assert_eq!( - f.split_totals_ct_ones_ct(5.0, 0, f.vals_.len()), + f.split_totals_ct_ones_ct(2, 0, f.vals_.len()), ((4.0, 2.0), (5.0, 3.0)) ); // The same test but on a subset of the data - assert_eq!( - f.split_totals_ct_ones_ct(5.0, 1, 5), - ((2.0, 1.0), (2.0, 1.0)) - ) + assert_eq!(f.split_totals_ct_ones_ct(1, 1, 5), ((2.0, 1.0), (2.0, 1.0))) } #[test] fn test_accumulate() { diff --git a/crates/discrust_core/src/node.rs b/crates/discrust_core/src/node.rs index cef8249..a3910f9 100644 --- a/crates/discrust_core/src/node.rs +++ b/crates/discrust_core/src/node.rs @@ -88,7 +88,7 @@ impl Node { fn eval_values<'a>(&self, feature: &'a Feature) -> &'a [f64] { // We do not need to evaluate the last value, as this is not a - // valid value becase 
there are no records greater than it. + // valid value because there are no records greater than it. feature.vals_[self.start..(self.stop - 1)].as_ref() } @@ -104,9 +104,9 @@ impl Node { let mut best_rhs_woe = 0.0; let mut best_split = -f64::INFINITY; - for v in self.eval_values(feature) { + for (i, v) in self.eval_values(feature).iter().enumerate() { let ((lhs_ct, lhs_ones), (rhs_ct, rhs_ones)) = - feature.split_totals_ct_ones_ct(*v, self.start, self.stop); + feature.split_totals_ct_ones_ct(i, self.start, self.stop); // Min response if (lhs_ones < self.min_pos) | (rhs_ones < self.min_pos) { continue; @@ -119,7 +119,7 @@ impl Node { // Get information value for split. let ((lhs_iv, lhs_woe), (rhs_iv, rhs_woe)) = - feature.split_iv_woe(*v, self.start, self.stop); + feature.split_iv_woe(i, self.start, self.stop); let total_iv = lhs_iv + rhs_iv; if total_iv < self.min_iv { @@ -140,7 +140,7 @@ impl Node { continue; } } else if split_sign == -1 { - continue; + continue; } } // Collect best