From 663db0334d8b41895df0e955bb8e429b8044a327 Mon Sep 17 00:00:00 2001 From: Alan Race Date: Mon, 11 Jul 2022 16:08:03 +0200 Subject: [PATCH 1/6] Added per-class probability prediction for random forests --- src/ensemble/random_forest_classifier.rs | 33 +++++++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/src/ensemble/random_forest_classifier.rs b/src/ensemble/random_forest_classifier.rs index 247b5025..87062f21 100644 --- a/src/ensemble/random_forest_classifier.rs +++ b/src/ensemble/random_forest_classifier.rs @@ -55,7 +55,8 @@ use serde::{Deserialize, Serialize}; use crate::api::{Predictor, SupervisedEstimator}; use crate::error::{Failed, FailedError}; -use crate::linalg::Matrix; +use crate::linalg::naive::dense_matrix::DenseMatrix; +use crate::linalg::{BaseMatrix, Matrix}; use crate::math::num::RealNumber; use crate::tree::decision_tree_classifier::{ which_max, DecisionTreeClassifier, DecisionTreeClassifierParameters, SplitCriterion, @@ -316,6 +317,36 @@ impl RandomForestClassifier { which_max(&result) } + /// Predict the per-class probabilties for each observation. The probability is calculated as the fraction of trees that predicted a given class + pub fn predict_probs>(&self, x: &M) -> Result, Failed> { + let mut result = DenseMatrix::::zeros(x.shape().0, self.classes.len()); + + let (n, _) = x.shape(); + + for i in 0..n { + let row_probs = self.predict_probs_for_row(x, i); + + for j in 0..row_probs.len() { + result.set(i, j, row_probs[j]); + } + } + + Ok(result) + } + + fn predict_probs_for_row>(&self, x: &M, row: usize) -> Vec { + let mut result = vec![0; self.classes.len()]; + + for tree in self.trees.iter() { + result[tree.predict_for_row(x, row)] += 1; + } + + result + .iter() + .map(|n| *n as f64 / self.trees.len() as f64) + .collect() + } + fn sample_with_replacement(y: &[usize], num_classes: usize, rng: &mut impl Rng) -> Vec { let class_weight = vec![1.; num_classes]; let nrows = y.len(); From 2603a1f42bfd922c25bb7eac9d2a34b9c44c6705 Mon Sep 17 00:00:00 2001 From: "Lorenzo (Mec-iS)" Date: Wed, 24 Aug 2022 11:44:30 +0100 Subject: [PATCH 2/6] Add test --- src/ensemble/random_forest_classifier.rs | 53 +++++++++++++++++++++++- 1 file changed, 52 insertions(+), 1 deletion(-) diff --git a/src/ensemble/random_forest_classifier.rs b/src/ensemble/random_forest_classifier.rs index 87062f21..dcfe41a0 100644 --- a/src/ensemble/random_forest_classifier.rs +++ b/src/ensemble/random_forest_classifier.rs @@ -372,7 +372,7 @@ impl RandomForestClassifier { } #[cfg(test)] -mod tests { +mod tests_prob { use super::*; use crate::linalg::naive::dense_matrix::DenseMatrix; use crate::metrics::*; @@ -513,4 +513,55 @@ mod tests { assert_eq!(forest, deserialized_forest); } + + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + #[test] + fn fit_predict_probabilities() { + let x = DenseMatrix::::from_2d_array(&[ + &[5.1, 3.5, 1.4, 0.2], + &[4.9, 3.0, 1.4, 0.2], + &[4.7, 3.2, 1.3, 0.2], + &[4.6, 3.1, 1.5, 0.2], + &[5.0, 3.6, 1.4, 0.2], + &[5.4, 3.9, 1.7, 0.4], + &[4.6, 3.4, 1.4, 0.3], + &[5.0, 3.4, 1.5, 0.2], + &[4.4, 2.9, 1.4, 0.2], + &[4.9, 3.1, 1.5, 0.1], + &[7.0, 3.2, 4.7, 1.4], + &[6.4, 3.2, 4.5, 1.5], + &[6.9, 3.1, 4.9, 1.5], + &[5.5, 2.3, 4.0, 1.3], + &[6.5, 2.8, 4.6, 1.5], + &[5.7, 2.8, 4.5, 1.3], + &[6.3, 3.3, 4.7, 1.6], + &[4.9, 2.4, 3.3, 1.0], + &[6.6, 2.9, 4.6, 1.3], + &[5.2, 2.7, 3.9, 1.4], + ]); + let y = vec![ + 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., + ]; + + let classifier = RandomForestClassifier::fit( + &x, + &y, + RandomForestClassifierParameters { + criterion: SplitCriterion::Gini, + max_depth: None, + min_samples_leaf: 1, + min_samples_split: 2, + n_trees: 100, + m: Option::None, + keep_samples: false, + seed: 87, + }, + ) + .unwrap(); + + let results = classifier.predict_probs(&x).unwrap(); + + println!("{:?}", results); + assert!(false); + } } From 61db4ebd90be072fbe9a81ef998a99d6a47d8519 Mon Sep 17 00:00:00 2001 From: "Lorenzo (Mec-iS)" Date: Wed, 24 Aug 2022 12:34:56 +0100 Subject: [PATCH 3/6] Add test --- src/ensemble/random_forest_classifier.rs | 25 ++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/src/ensemble/random_forest_classifier.rs b/src/ensemble/random_forest_classifier.rs index dcfe41a0..baf6901c 100644 --- a/src/ensemble/random_forest_classifier.rs +++ b/src/ensemble/random_forest_classifier.rs @@ -317,7 +317,8 @@ impl RandomForestClassifier { which_max(&result) } - /// Predict the per-class probabilties for each observation. The probability is calculated as the fraction of trees that predicted a given class + /// Predict the per-class probabilties for each observation. + /// The probability is calculated as the fraction of trees that predicted a given class pub fn predict_probs>(&self, x: &M) -> Result, Failed> { let mut result = DenseMatrix::::zeros(x.shape().0, self.classes.len()); @@ -326,8 +327,8 @@ impl RandomForestClassifier { for i in 0..n { let row_probs = self.predict_probs_for_row(x, i); - for j in 0..row_probs.len() { - result.set(i, j, row_probs[j]); + for (j, item) in row_probs.iter().enumerate() { + result.set(i, j, *item); } } @@ -559,9 +560,25 @@ mod tests_prob { ) .unwrap(); - let results = classifier.predict_probs(&x).unwrap(); + println!("{:?}", classifier.classes); + let results = classifier.predict_probs(&x).unwrap(); + println!("{:?}", x.shape()); println!("{:?}", results); + println!("{:?}", results.shape()); + + assert_eq!( + results, + DenseMatrix::::from_array( + 20, + 2, + &[ + 1.0, 0.78, 0.95, 0.82, 1.0, 0.92, 0.99, 0.96, 0.36, 0.33, 0.02, 0.02, 0.0, 0.0, + 0.0, 0.0, 0.03, 0.05, 0.0, 0.02, 0.0, 0.22, 0.05, 0.18, 0.0, 0.08, 0.01, 0.04, + 0.64, 0.67, 0.98, 0.98, 1.0, 1.0, 1.0, 1.0, 0.97, 0.95, 1.0, 0.98 + ] + ) + ); assert!(false); } } From 7f7b2edca0f326c9106f1e32da7123ac9765ea9f Mon Sep 17 00:00:00 2001 From: Alan Race Date: Mon, 29 Aug 2022 16:25:21 +0200 Subject: [PATCH 4/6] Fixed test by transposing matrix --- src/ensemble/random_forest_classifier.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/ensemble/random_forest_classifier.rs b/src/ensemble/random_forest_classifier.rs index baf6901c..02841fcc 100644 --- a/src/ensemble/random_forest_classifier.rs +++ b/src/ensemble/random_forest_classifier.rs @@ -570,15 +570,15 @@ mod tests_prob { assert_eq!( results, DenseMatrix::::from_array( - 20, 2, + 20, &[ 1.0, 0.78, 0.95, 0.82, 1.0, 0.92, 0.99, 0.96, 0.36, 0.33, 0.02, 0.02, 0.0, 0.0, 0.0, 0.0, 0.03, 0.05, 0.0, 0.02, 0.0, 0.22, 0.05, 0.18, 0.0, 0.08, 0.01, 0.04, 0.64, 0.67, 0.98, 0.98, 1.0, 1.0, 1.0, 1.0, 0.97, 0.95, 1.0, 0.98 ] ) + .transpose() ); - assert!(false); } } From 28c81eb3584986c35b451b7911d5474ec7a7ad3b Mon Sep 17 00:00:00 2001 From: Alan Race Date: Tue, 30 Aug 2022 11:08:35 +0200 Subject: [PATCH 5/6] Test case now passing without transpose --- src/ensemble/random_forest_classifier.rs | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/ensemble/random_forest_classifier.rs b/src/ensemble/random_forest_classifier.rs index 02841fcc..c07487c3 100644 --- a/src/ensemble/random_forest_classifier.rs +++ b/src/ensemble/random_forest_classifier.rs @@ -570,15 +570,14 @@ mod tests_prob { assert_eq!( results, DenseMatrix::::from_array( - 2, 20, + 2, &[ - 1.0, 0.78, 0.95, 0.82, 1.0, 0.92, 0.99, 0.96, 0.36, 0.33, 0.02, 0.02, 0.0, 0.0, - 0.0, 0.0, 0.03, 0.05, 0.0, 0.02, 0.0, 0.22, 0.05, 0.18, 0.0, 0.08, 0.01, 0.04, - 0.64, 0.67, 0.98, 0.98, 1.0, 1.0, 1.0, 1.0, 0.97, 0.95, 1.0, 0.98 + 1.0, 0.0, 0.78, 0.22, 0.95, 0.05, 0.82, 0.18, 1.0, 0.0, 0.92, 0.08, 0.99, 0.01, + 0.96, 0.04, 0.36, 0.64, 0.33, 0.67, 0.02, 0.98, 0.02, 0.98, 0.0, 1.0, 0.0, 1.0, + 0.0, 1.0, 0.0, 1.0, 0.03, 0.97, 0.05, 0.95, 0.0, 1.0, 0.02, 0.98 ] ) - .transpose() ); } } From 78e53a28e7aa67df7300f9713dabe8b3271c9d87 Mon Sep 17 00:00:00 2001 From: "Lorenzo (Mec-iS)" Date: Mon, 31 Oct 2022 19:28:24 +0000 Subject: [PATCH 6/6] apply fmt --- src/ensemble/random_forest_classifier.rs | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/ensemble/random_forest_classifier.rs b/src/ensemble/random_forest_classifier.rs index 749f4fd6..44bd4e38 100644 --- a/src/ensemble/random_forest_classifier.rs +++ b/src/ensemble/random_forest_classifier.rs @@ -820,9 +820,7 @@ mod tests { &[6.6, 2.9, 4.6, 1.3], &[5.2, 2.7, 3.9, 1.4], ]); - let y = vec![ - 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - ]; + let y = vec![0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]; let classifier = RandomForestClassifier::fit( &x, @@ -832,7 +830,7 @@ mod tests { max_depth: None, min_samples_leaf: 1, min_samples_split: 2, - n_trees: 100, // this is n_estimators in sklearn + n_trees: 100, // this is n_estimators in sklearn m: Option::None, keep_samples: false, seed: 0,