Skip to content

Commit

Permalink
Merge pull request #1 from smartcorelib/alanrace-predict-probs
Browse files Browse the repository at this point in the history
Add test to predict probabilities
  • Loading branch information
AlanRace committed Aug 29, 2022
2 parents b4a807e + 61db4eb commit b6fb819
Showing 1 changed file with 101 additions and 2 deletions.
103 changes: 101 additions & 2 deletions src/ensemble/random_forest_classifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ use serde::{Deserialize, Serialize};

use crate::api::{Predictor, SupervisedEstimator};
use crate::error::{Failed, FailedError};
use crate::linalg::Matrix;
use crate::linalg::naive::dense_matrix::DenseMatrix;
use crate::linalg::{BaseMatrix, Matrix};
use crate::math::num::RealNumber;
use crate::tree::decision_tree_classifier::{
which_max, DecisionTreeClassifier, DecisionTreeClassifierParameters, SplitCriterion,
Expand Down Expand Up @@ -316,6 +317,37 @@ impl<T: RealNumber> RandomForestClassifier<T> {
which_max(&result)
}

/// Predict the per-class probabilties for each observation.
/// The probability is calculated as the fraction of trees that predicted a given class
pub fn predict_probs<M: Matrix<T>>(&self, x: &M) -> Result<DenseMatrix<f64>, Failed> {
let mut result = DenseMatrix::<f64>::zeros(x.shape().0, self.classes.len());

let (n, _) = x.shape();

for i in 0..n {
let row_probs = self.predict_probs_for_row(x, i);

for (j, item) in row_probs.iter().enumerate() {
result.set(i, j, *item);
}
}

Ok(result)
}

fn predict_probs_for_row<M: Matrix<T>>(&self, x: &M, row: usize) -> Vec<f64> {
let mut result = vec![0; self.classes.len()];

for tree in self.trees.iter() {
result[tree.predict_for_row(x, row)] += 1;
}

result
.iter()
.map(|n| *n as f64 / self.trees.len() as f64)
.collect()
}

fn sample_with_replacement(y: &[usize], num_classes: usize, rng: &mut impl Rng) -> Vec<usize> {
let class_weight = vec![1.; num_classes];
let nrows = y.len();
Expand All @@ -341,7 +373,7 @@ impl<T: RealNumber> RandomForestClassifier<T> {
}

#[cfg(test)]
mod tests {
mod tests_prob {
use super::*;
use crate::linalg::naive::dense_matrix::DenseMatrix;
use crate::metrics::*;
Expand Down Expand Up @@ -482,4 +514,71 @@ mod tests {

assert_eq!(forest, deserialized_forest);
}

#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn fit_predict_probabilities() {
let x = DenseMatrix::<f64>::from_2d_array(&[
&[5.1, 3.5, 1.4, 0.2],
&[4.9, 3.0, 1.4, 0.2],
&[4.7, 3.2, 1.3, 0.2],
&[4.6, 3.1, 1.5, 0.2],
&[5.0, 3.6, 1.4, 0.2],
&[5.4, 3.9, 1.7, 0.4],
&[4.6, 3.4, 1.4, 0.3],
&[5.0, 3.4, 1.5, 0.2],
&[4.4, 2.9, 1.4, 0.2],
&[4.9, 3.1, 1.5, 0.1],
&[7.0, 3.2, 4.7, 1.4],
&[6.4, 3.2, 4.5, 1.5],
&[6.9, 3.1, 4.9, 1.5],
&[5.5, 2.3, 4.0, 1.3],
&[6.5, 2.8, 4.6, 1.5],
&[5.7, 2.8, 4.5, 1.3],
&[6.3, 3.3, 4.7, 1.6],
&[4.9, 2.4, 3.3, 1.0],
&[6.6, 2.9, 4.6, 1.3],
&[5.2, 2.7, 3.9, 1.4],
]);
let y = vec![
0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
];

let classifier = RandomForestClassifier::fit(
&x,
&y,
RandomForestClassifierParameters {
criterion: SplitCriterion::Gini,
max_depth: None,
min_samples_leaf: 1,
min_samples_split: 2,
n_trees: 100,
m: Option::None,
keep_samples: false,
seed: 87,
},
)
.unwrap();

println!("{:?}", classifier.classes);

let results = classifier.predict_probs(&x).unwrap();
println!("{:?}", x.shape());
println!("{:?}", results);
println!("{:?}", results.shape());

assert_eq!(
results,
DenseMatrix::<f64>::from_array(
20,
2,
&[
1.0, 0.78, 0.95, 0.82, 1.0, 0.92, 0.99, 0.96, 0.36, 0.33, 0.02, 0.02, 0.0, 0.0,
0.0, 0.0, 0.03, 0.05, 0.0, 0.02, 0.0, 0.22, 0.05, 0.18, 0.0, 0.08, 0.01, 0.04,
0.64, 0.67, 0.98, 0.98, 1.0, 1.0, 1.0, 1.0, 0.97, 0.95, 1.0, 0.98
]
)
);
assert!(false);
}
}

0 comments on commit b6fb819

Please sign in to comment.