diff --git a/src/ensemble/random_forest_classifier.rs b/src/ensemble/random_forest_classifier.rs index dabb2480..0cff7810 100644 --- a/src/ensemble/random_forest_classifier.rs +++ b/src/ensemble/random_forest_classifier.rs @@ -580,6 +580,37 @@ impl, Y: Array1>(&self, x: &X) -> Result { + let mut result: R = R::zeros(x.shape().0, self.classes.as_ref().unwrap().len()); + + let (n, _) = x.shape(); + + for i in 0..n { + let row_probs = self.predict_proba_for_row(x, i); + + for (j, item) in row_probs.iter().enumerate() { + result.set((i, j), *item); + } + } + + Ok(result) + } + + fn predict_proba_for_row(&self, x: &X, row: usize) -> Vec { + let mut result = vec![0; self.classes.as_ref().unwrap().len()]; + + for tree in self.trees.as_ref().unwrap().iter() { + result[tree.predict_for_row(x, row)] += 1; + } + + result + .iter() + .map(|n| *n as f64 / self.trees.as_ref().unwrap().len() as f64) + .collect() + } + fn sample_with_replacement(y: &[usize], num_classes: usize, rng: &mut impl Rng) -> Vec { let class_weight = vec![1.; num_classes]; let nrows = y.len(); @@ -607,6 +638,7 @@ impl, Y: Array1::from_2d_array(&[ + &[5.1, 3.5, 1.4, 0.2], + &[4.9, 3.0, 1.4, 0.2], + &[4.7, 3.2, 1.3, 0.2], + &[4.6, 3.1, 1.5, 0.2], + &[5.0, 3.6, 1.4, 0.2], + &[5.4, 3.9, 1.7, 0.4], + &[4.6, 3.4, 1.4, 0.3], + &[5.0, 3.4, 1.5, 0.2], + &[4.4, 2.9, 1.4, 0.2], + &[4.9, 3.1, 1.5, 0.1], + &[7.0, 3.2, 4.7, 1.4], + &[6.4, 3.2, 4.5, 1.5], + &[6.9, 3.1, 4.9, 1.5], + &[5.5, 2.3, 4.0, 1.3], + &[6.5, 2.8, 4.6, 1.5], + &[5.7, 2.8, 4.5, 1.3], + &[6.3, 3.3, 4.7, 1.6], + &[4.9, 2.4, 3.3, 1.0], + &[6.6, 2.9, 4.6, 1.3], + &[5.2, 2.7, 3.9, 1.4], + ]); + let y = vec![0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]; + + let classifier = RandomForestClassifier::fit( + &x, + &y, + RandomForestClassifierParameters { + criterion: SplitCriterion::Gini, + max_depth: None, + min_samples_leaf: 1, + min_samples_split: 2, + n_trees: 100, // this is n_estimators in sklearn + m: Option::None, + keep_samples: false, + seed: 0, + }, + ) + .unwrap(); + + println!("{:?}", classifier.classes); + + let results: DenseMatrix = classifier.predict_proba(&x).unwrap(); + println!("{:?}", x.shape()); + println!("{:?}", results); + println!("{:?}", results.shape()); + + assert_eq!( + results, + DenseMatrix::::new( + 20, + 2, + vec![ + 1.0, 0.0, 0.78, 0.22, 0.95, 0.05, 0.82, 0.18, 1.0, 0.0, 0.92, 0.08, 0.99, 0.01, + 0.96, 0.04, 0.36, 0.64, 0.33, 0.67, 0.02, 0.98, 0.02, 0.98, 0.0, 1.0, 0.0, 1.0, + 0.0, 1.0, 0.0, 1.0, 0.03, 0.97, 0.05, 0.95, 0.0, 1.0, 0.02, 0.98 + ], + true + ) + ); + } }