From 56326a90d05c4d2555bbba516c42b09d51f44a16 Mon Sep 17 00:00:00 2001 From: jinlow Date: Sun, 8 May 2022 15:59:06 -0500 Subject: [PATCH] feat: Only use nan-safe-compare if needed. --- crates/discrust_core/src/discretize.rs | 21 +-- crates/discrust_core/src/feature.rs | 245 +++++++++++-------------- crates/discrust_core/src/node.rs | 18 +- crates/discrust_core/src/utils.rs | 29 ++- src/lib.rs | 2 +- 5 files changed, 139 insertions(+), 176 deletions(-) diff --git a/crates/discrust_core/src/discretize.rs b/crates/discrust_core/src/discretize.rs index b3a77b5..320ef3e 100644 --- a/crates/discrust_core/src/discretize.rs +++ b/crates/discrust_core/src/discretize.rs @@ -1,7 +1,7 @@ use crate::errors::DiscrustError; use crate::feature::Feature; use crate::node::{Node, NodePtr}; -use crate::utils::{first_greater_than, nan_safe_compare}; +use crate::utils::nan_safe_compare; use std::cmp::Ordering; use std::collections::VecDeque; @@ -94,7 +94,7 @@ impl Discretizer { self.mono = Some(split_sign); } - let split_idx = first_greater_than(&feature.vals_, &split); + let idx = info.split_idx.unwrap() + node.start + 1; let lhs_node = Node::new( &feature, @@ -105,7 +105,7 @@ impl Discretizer { info.lhs_woe, info.lhs_iv, Some(node.start), - Some(split_idx), + Some(idx), ); let rhs_node = Node::new( &feature, @@ -115,7 +115,7 @@ impl Discretizer { self.mono, info.rhs_woe, info.rhs_iv, - Some(split_idx), + Some(idx), Some(node.stop), ); @@ -167,7 +167,7 @@ impl Discretizer { // If it's an exception value, we return the index negative value. // We start this at -1. So we add 1, to the zero indexed result // of the `exception_idx` function. - if let Some(i) = feature.exception_values.exception_idx(v) { + if let Some(i) = feature.exception_values_.exception_idx(v) { return Ok(-((i + 1) as i64)); } let idx = all_splits @@ -180,17 +180,14 @@ impl Discretizer { } // -1, 4, 10 fn predict_record_woe(&self, v: &f64, feature: &Feature) -> Result { - let excp_idx = feature.exception_values.exception_idx(v); + let excp_idx = feature.exception_values_.exception_idx(v); if let Some(idx) = excp_idx { - if feature.exception_values.totals_ct_[idx] == 0.0 { + if feature.exception_values_.totals_ct_[idx] == 0.0 { return Ok(0.0); } - return Ok(feature.exception_values.woe_[idx]); + return Ok(feature.exception_values_.woe_[idx]); } - let mut node = self - .root_node - .as_ref() - .ok_or(DiscrustError::NotFitted)?; + let mut node = self.root_node.as_ref().ok_or(DiscrustError::NotFitted)?; let w: f64; loop { if node.is_terminal() { diff --git a/crates/discrust_core/src/feature.rs b/crates/discrust_core/src/feature.rs index 79edfe8..0413259 100644 --- a/crates/discrust_core/src/feature.rs +++ b/crates/discrust_core/src/feature.rs @@ -12,10 +12,11 @@ use std::{cmp::Ordering, collections::HashMap}; pub struct Feature { pub vals_: Vec, cuml_ones_ct_: Vec, + cuml_zero_ct_: Vec, cuml_totals_ct_: Vec, - cuml_ones_dist_: Vec, - cuml_zero_dist_: Vec, - pub exception_values: ExceptionValues, + total_ones_: f64, + total_zero_: f64, + pub exception_values_: ExceptionValues, } #[derive(Debug)] @@ -105,151 +106,118 @@ impl Feature { exception_values: &[f64], ) -> Result { // Make exception values. - let mut exception_values = ExceptionValues::new(exception_values); + let mut exception_values_ = ExceptionValues::new(exception_values); // Define all of the stats we will use let mut vals_ = Vec::new(); - let mut ones_ct_ = Vec::new(); - let mut totals_ct_ = Vec::new(); - let mut ones_dist_ = Vec::new(); - let mut zero_dist_ = Vec::new(); + let mut cuml_ones_ct_ = Vec::new(); + let mut cuml_zero_ct_ = Vec::new(); + let mut cuml_totals_ct_ = Vec::new(); // First we will get the index needed to sort the vector x. let mut sort_tuples: Vec<(usize, &f64)> = x.iter().enumerate().collect(); - // Now sort these tuples by the float values of x. - sort_tuples.sort_by(|a, b| nan_safe_compare(a.1, b.1)); - // Now that we have the tuples sorted, we only need the index, we will loop over - // the data again here, to retrieve the sort index. - // Maybe if we need to speed things up again, we can consider skipping this part and - // and only use the the tuples? Here we just set it to the iterator, so it's - // not actually consumed. - let mut sort_index = sort_tuples.iter().map(|(i, _)| *i); - - // Now loop over all columns, collecting aggregate statistics. - let mut zero_ct: Vec = Vec::new(); + let no_exceptions = exception_values.is_empty(); + if no_exceptions { + // Now sort these tuples by the float values of x. + sort_tuples.sort_by(|a, b| a.1.partial_cmp(b.1).unwrap()); + } else { + sort_tuples.sort_by(|a, b| nan_safe_compare(a.1, b.1)); + + }; + let sort_index = sort_tuples.iter().map(|(i, _)| *i); - // Do this part in one pass for both. - let mut total_ones = 0.0; - let mut total_zero = 0.0; - for (w_, y_) in w.iter().zip(y) { - // Confirm there are no NaN values in the y_ variable, or - // or the weight field. + let mut totals_idx = 0; + let mut first_value = true; + let mut x_ = f64::NAN; + let mut y_; + let mut w_; + let mut total_ones_ = 0.0; + let mut total_zero_ = 0.0; + for i in sort_index { + y_ = y[i]; + w_ = w[i]; + // Some error checking if y_.is_nan() { return Err(DiscrustError::ContainsNaN(String::from("y column"))); } if w_.is_nan() { return Err(DiscrustError::ContainsNaN(String::from("weight column"))); } - - total_ones += w_ * y_; - // I am using greater than 0 here, because - // floats only have partial equality. - total_zero += w_ * ((y_ < &1.0) as i64 as f64); - } - - // Now loop over the data collecting all relevant stats. - // We grab the first item from the iterator, and start the sums - // of all relevant fields. - let mut init_idx = sort_index.next().unwrap_or(0); - - // Check if NaN is in the vector, but not an exception. - // The NaN values will be at the beginning because it's sorted. - if x[init_idx].is_nan() && exception_values.exception_idx(&x[init_idx]) == None { - return Err(DiscrustError::ContainsNaN(String::from( - "x column, but NaN is not an exception value", - ))); - } - - // If this first value is an exception, we need to loop until - // we find a non-exception value. Or consume the data. In which - // case all values in the data are exception_values. - // If the first value is not an exception, we just leave the - // init index alone. - if let Some(idx) = exception_values.exception_idx(&x[init_idx]) { - exception_values.update_exception_values(idx, &w[init_idx], &y[init_idx]); - // Start searching through the vector to find the first non-exception - // value. - loop { - let i_op = sort_index.next(); - if i_op.is_none() { - break; + if !no_exceptions { + let e_idx = exception_values_.exception_idx(&x[i]); + if x[i].is_nan() && e_idx == None { + return Err(DiscrustError::ContainsNaN(String::from( + "x column, but NaN is not an exception value", + ))); } - let i = i_op.unwrap(); - match exception_values.exception_idx(&x[i]) { - Some(idx) => { - exception_values.update_exception_values(idx, &w[i], &y[i]); + // If the value is equal to one of our exception_values_ update the exception_values_ + // and continue. + if let Some(idx) = e_idx { + exception_values_.update_exception_values(idx, &w_, &y_); + if y_ == 1.0 { + total_ones_ += w_; + } else { + total_zero_ += w_; } - // If we have reached a point in the loop where the value - // is no longer an exception, update the init_idx - // and then break out of the loop. - None => { - init_idx = i; - break; - } - }; - } - }; - - // TODO: At this point we run the risk of actually having gone through - // all the values. Because of this, we need to add some check here - // incase we have totally consumed sort_index. - let mut x_ = x[init_idx]; - vals_.push(x_); - totals_ct_.push(w[init_idx]); - ones_ct_.push(w[init_idx] * y[init_idx]); - zero_ct.push(w[init_idx] * ((y[init_idx] < 1.0) as i64 as f64)); - let mut totals_idx = 0; - - // This will start at the second element. - for i in sort_index { - // If the value is equal to one of our exception_values update the exception_values - // and continue. - if let Some(idx) = exception_values.exception_idx(&x[i]) { - exception_values.update_exception_values(idx, &w[i], &y[i]); - continue; + continue; + } } - - // if the value is greater than x_ we know we are at a new - // value and can calculate the distributions, as well as increment - // the totals_idx. - if x_ < x[i] { - // We update x_ to the new value. + // If this is the first value, add to our vectors + // Initializing them. + if first_value { x_ = x[i]; - // We calculate the distribution values - ones_dist_.push(ones_ct_[totals_idx] / total_ones); - zero_dist_.push(zero_ct[totals_idx] / total_zero); - - // Update the values - totals_ct_.push(w[i]); - ones_ct_.push(w[i] * y[i]); - zero_ct.push(w[i] * ((y[i] < 1.0) as i64 as f64)); + cuml_totals_ct_.push(w_); + if y_ == 1.0 { + total_ones_ += w_; + cuml_ones_ct_.push(w_); + cuml_zero_ct_.push(0.0); + } else { + total_zero_ += w_; + cuml_ones_ct_.push(0.0); + cuml_zero_ct_.push(w_); + } + vals_.push(x_); + // If this is a new value, push, and cumulate + // this won't panic, because we know these + // vectors have values + } else if x_ < x[i] { + let t_last = cuml_totals_ct_[totals_idx]; + let o_last = cuml_ones_ct_[totals_idx]; + let z_last = cuml_zero_ct_[totals_idx]; + x_ = x[i]; + cuml_totals_ct_.push(w_ + t_last); + if y_ == 1.0 { + total_ones_ += w_; + cuml_ones_ct_.push(w_ + o_last); + cuml_zero_ct_.push(z_last); + } else { + total_zero_ += w_; + cuml_ones_ct_.push(o_last); + cuml_zero_ct_.push(w_ + z_last); + } vals_.push(x_); totals_idx += 1; } else { - // Otherwise just add this value to our current aggregations - totals_ct_[totals_idx] += w[i]; - ones_ct_[totals_idx] += w[i] * y[i]; - zero_ct[totals_idx] += w[i] * ((y[i] < 1.0) as i64 as f64); + cuml_totals_ct_[totals_idx] += w_; + if y_ == 1.0 { + total_ones_ += w_; + cuml_ones_ct_[totals_idx] += w_; + } else { + total_zero_ += w_; + cuml_zero_ct_[totals_idx] += w_; + } } + first_value = false; } - // Finally add the very last value to our distribution columns - ones_dist_.push(ones_ct_[totals_idx] / total_ones); - zero_dist_.push(zero_ct[totals_idx] / total_zero); - - exception_values.calculate_iv_woe(total_ones, total_zero); - - // Generate cumulative sums left to right. - let cuml_ones_ct_: Vec = cuml_array(&ones_ct_); - let cuml_totals_ct_: Vec = cuml_array(&totals_ct_); - let cuml_ones_dist_: Vec = cuml_array(&ones_dist_); - let cuml_zero_dist_: Vec = cuml_array(&zero_dist_); + exception_values_.calculate_iv_woe(total_ones_, total_zero_); Ok(Feature { vals_, cuml_ones_ct_, + cuml_zero_ct_, cuml_totals_ct_, - cuml_ones_dist_, - cuml_zero_dist_, - exception_values, + total_ones_, + total_zero_, + exception_values_, }) } @@ -271,14 +239,18 @@ impl Feature { let split_idx = split_idx + 1 + start; // Accumulate the left hand side. - let lhs_zero_dist = sum_of_cuml_subarray(&self.cuml_zero_dist_, start, split_idx - 1); - let lhs_ones_dist = sum_of_cuml_subarray(&self.cuml_ones_dist_, start, split_idx - 1); + let lhs_zero_dist = + sum_of_cuml_subarray(&self.cuml_zero_ct_, start, split_idx - 1) / self.total_zero_; + let lhs_ones_dist = + sum_of_cuml_subarray(&self.cuml_ones_ct_, start, split_idx - 1) / self.total_ones_; let lhs_woe = (lhs_ones_dist / lhs_zero_dist).ln(); let lhs_iv = (lhs_ones_dist - lhs_zero_dist) * lhs_woe; // Accumulate the right hand side. - let rhs_zero_dist = sum_of_cuml_subarray(&self.cuml_zero_dist_, split_idx, stop - 1); - let rhs_ones_dist = sum_of_cuml_subarray(&self.cuml_ones_dist_, split_idx, stop - 1); + let rhs_zero_dist = + sum_of_cuml_subarray(&self.cuml_zero_ct_, split_idx, stop - 1) / self.total_zero_; + let rhs_ones_dist = + sum_of_cuml_subarray(&self.cuml_ones_ct_, split_idx, stop - 1) / self.total_ones_; let rhs_woe = (rhs_ones_dist / rhs_zero_dist).ln(); let rhs_iv = (rhs_ones_dist - rhs_zero_dist) * rhs_woe; @@ -311,15 +283,6 @@ fn sum_of_cuml_subarray(x: &[f64], start: usize, stop: usize) -> f64 { } } -fn cuml_array(x: &[f64]) -> Vec { - x.iter() - .scan(0.0, |acc, &x| { - *acc += x; - Some(*acc) - }) - .collect() -} - #[cfg(test)] mod test { use super::*; @@ -347,11 +310,11 @@ mod test { assert_eq!(f.vals_, vec![1.0, 2.0]); assert_eq!(f.cuml_totals_ct_, vec![2.0, 8.0]); assert_eq!(f.cuml_ones_ct_, vec![1.0, 7.0]); - assert_eq!( - f.cuml_ones_dist_, - vec![1.0 / 7.0, (1.0 / 7.0) + (6.0 / 7.0)] - ); - assert_eq!(f.cuml_zero_dist_, vec![1.0 / 1.0, 1.0]); + // assert_eq!( + // f.cuml_ones_dist_, + // vec![1.0 / 7.0, (1.0 / 7.0) + (6.0 / 7.0)] + // ); + // assert_eq!(f.cuml_zero_dist_, vec![1.0 / 1.0, 1.0]); } #[test] @@ -378,7 +341,7 @@ mod test { // (0.011157177565710483, -0.2231435513142097), // (0.011157177565710483, -0.2231435513142097) (0.011157177565710483, -0.2231435513142097), - (0.011157177565710457, -0.22314355131420943) + (0.011157177565710483, -0.2231435513142097) ) ) } diff --git a/crates/discrust_core/src/node.rs b/crates/discrust_core/src/node.rs index a3910f9..8e8a51e 100644 --- a/crates/discrust_core/src/node.rs +++ b/crates/discrust_core/src/node.rs @@ -4,6 +4,7 @@ use std::cmp::PartialEq; #[derive(Debug, PartialEq)] pub struct SplitInfo { pub split: Option, + pub split_idx: Option, pub lhs_iv: Option, pub lhs_woe: Option, pub rhs_iv: Option, @@ -11,9 +12,10 @@ pub struct SplitInfo { } impl SplitInfo { - pub fn new(split: f64, lhs_iv: f64, lhs_woe: f64, rhs_iv: f64, rhs_woe: f64) -> Self { + pub fn new(split: f64, split_idx: usize, lhs_iv: f64, lhs_woe: f64, rhs_iv: f64, rhs_woe: f64) -> Self { SplitInfo { split: Some(split), + split_idx: Some(split_idx), lhs_iv: Some(lhs_iv), lhs_woe: Some(lhs_woe), rhs_iv: Some(rhs_iv), @@ -23,6 +25,7 @@ impl SplitInfo { pub fn new_empty() -> Self { SplitInfo { split: None, + split_idx: None, lhs_iv: None, lhs_woe: None, rhs_iv: None, @@ -103,6 +106,7 @@ impl Node { let mut best_rhs_iv = 0.0; let mut best_rhs_woe = 0.0; let mut best_split = -f64::INFINITY; + let mut best_split_idx = 0; for (i, v) in self.eval_values(feature).iter().enumerate() { let ((lhs_ct, lhs_ones), (rhs_ct, rhs_ones)) = @@ -147,6 +151,7 @@ impl Node { if total_iv > best_iv { best_iv = total_iv; best_split = *v; + best_split_idx = i; best_lhs_iv = lhs_iv; best_lhs_woe = lhs_woe; best_rhs_iv = rhs_iv; @@ -158,6 +163,7 @@ impl Node { } else { SplitInfo::new( best_split, + best_split_idx, best_lhs_iv, best_lhs_woe, best_rhs_iv, @@ -190,6 +196,7 @@ mod test { ); let comp_info = SplitInfo::new( 6.2375, + 3, 0.22001303079783097, -0.6286086594223742, 0.3064140580738649, @@ -220,7 +227,7 @@ mod test { None, None, ); - println!("{:?}", f.exception_values); + println!("{:?}", f.exception_values_); assert_eq!(n.find_best_split(&f).split.unwrap(), 6.2375); let f = Feature::new(&x_, &y_, &w_, &Vec::new()).unwrap(); @@ -235,7 +242,7 @@ mod test { None, None, ); - println!("{:?}", f.exception_values); + println!("{:?}", f.exception_values_); assert_ne!(n.find_best_split(&f).split.unwrap(), 6.2375); } @@ -266,10 +273,11 @@ mod test { println!("{:?}", n.find_best_split(&f)); let test_info = SplitInfo { split: Some(6.4375), + split_idx: Some(0), lhs_iv: Some(f64::INFINITY), lhs_woe: Some(-f64::INFINITY), - rhs_iv: Some(0.08392941911181283), - rhs_woe: Some(-1.0168034503546095), + rhs_iv: Some(0.08392941911181269), + rhs_woe: Some(-1.0168034503546088), }; assert_eq!(n.find_best_split(&f), test_info); } diff --git a/crates/discrust_core/src/utils.rs b/crates/discrust_core/src/utils.rs index e9ee20f..389103e 100644 --- a/crates/discrust_core/src/utils.rs +++ b/crates/discrust_core/src/utils.rs @@ -2,17 +2,21 @@ use num::Float; use std::cmp::Ordering; pub fn nan_safe_compare(i: &T, j: &T) -> Ordering { - match (i.is_nan(), j.is_nan()) { - (true, true) => Ordering::Equal, - (true, false) => Ordering::Less, - (false, true) => Ordering::Greater, - (false, false) => i.partial_cmp(j).unwrap(), + match i.partial_cmp(j) { + Some(o) => o, + None => match (i.is_nan(), j.is_nan()) { + (true, true) => Ordering::Equal, + (true, false) => Ordering::Less, + (false, true) => Ordering::Greater, + (false, false) => Ordering::Equal, + }, } } /// Take a sorted array, and find the position /// of the first value that is less than some target /// value. +#[allow(dead_code)] pub fn first_greater_than(x: &[T], v: &T) -> usize { let mut low = 0; let mut high = x.len(); @@ -47,20 +51,11 @@ mod test { assert_eq!(1, first_greater_than(&v, &1)); assert_eq!(3, first_greater_than(&v, &2)); assert_eq!(5, first_greater_than(&v, &5)); - let i = (&v) - .iter() - .position(|&v| v > 2) - .unwrap(); + let i = (&v).iter().position(|&v| v > 2).unwrap(); assert_eq!(3, i); - let i = (&v) - .iter() - .position(|&v| v > 0) - .unwrap(); + let i = (&v).iter().position(|&v| v > 0).unwrap(); assert_eq!(1, i); - let i = (&v) - .iter() - .position(|&v| v > 5) - .unwrap(); + let i = (&v).iter().position(|&v| v > 5).unwrap(); assert_eq!(5, i); } } diff --git a/src/lib.rs b/src/lib.rs index 737646b..c0d6ab5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -46,7 +46,7 @@ impl Discretizer { .feature .as_ref() .unwrap() - .exception_values + .exception_values_ .to_hashmap()) }