
Added per-class probability prediction for random forests #138

Closed

Conversation

@AlanRace commented Jul 11, 2022

Added a function to predict the per-class probability of each class for each observation.

let probabilities = forest.predict_probs(&data).unwrap();

probabilities is a KxC matrix, where K is the number of observations and C is the number of classes. Probabilities are calculated as the fraction of trees in the random forest that predicted the given class.

Answer to #50 for random forests.
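(As an aside, a minimal sketch of the fraction-of-trees computation described above — illustrative only, not the PR's actual code; the function and variable names are hypothetical. Each tree votes for one class, and the probability of a class is its vote count divided by the number of trees.)

// Given one observation's per-tree predicted class labels, return the
// fraction of trees that voted for each of `n_classes` classes.
fn class_probabilities(tree_votes: &[usize], n_classes: usize) -> Vec<f64> {
    let mut counts = vec![0usize; n_classes];
    for &class in tree_votes {
        counts[class] += 1;
    }
    let n_trees = tree_votes.len() as f64;
    counts.into_iter().map(|c| c as f64 / n_trees).collect()
}

// Example: 10 trees, 7 voting for class 0 and 3 for class 1, gives [0.7, 0.3].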

@VolodymyrOrlov (Collaborator)

@AlanRace thank you for your contribution to Smartcore! The change looks good, but you might want to look at clippy warnings as well as increase test coverage to get this code through automatic checks

@Mec-iS (Collaborator) commented Aug 24, 2022

I have added a test here but there is something wrong, please take a look:
AlanRace#1

@AlanRace (Author)

@Mec-iS Thanks for supplying the test - I am guessing there was a problem due to row-major vs column-major storage of DenseMatrix? Swapping the number of rows and columns in your test and then transposing the matrix results in a passing test.

@codecov-commenter

Codecov Report

Merging #138 (7f7b2ed) into development (b4a807e) will increase coverage by 0.60%.
The diff coverage is 100.00%.

@@               Coverage Diff               @@
##           development     #138      +/-   ##
===============================================
+ Coverage        83.40%   84.01%   +0.60%     
===============================================
  Files               78       81       +3     
  Lines             8377     8751     +374     
===============================================
+ Hits              6987     7352     +365     
- Misses            1390     1399       +9     
Impacted Files Coverage Δ
src/ensemble/random_forest_classifier.rs 75.58% <100.00%> (+4.54%) ⬆️
src/linalg/evd.rs 86.06% <0.00%> (ø)
src/linear/lasso_optimizer.rs 94.11% <0.00%> (ø)
src/algorithm/neighbour/mod.rs 78.57% <0.00%> (ø)
src/algorithm/neighbour/distances.rs 66.66% <0.00%> (ø)
src/preprocessing/numerical.rs 88.88% <0.00%> (ø)
src/algorithm/neighbour/fastpair.rs 95.67% <0.00%> (ø)
src/linalg/naive/dense_matrix.rs 80.11% <0.00%> (+0.89%) ⬆️
src/optimization/first_order/lbfgs.rs 94.44% <0.00%> (+1.58%) ⬆️
src/linalg/mod.rs 58.57% <0.00%> (+5.49%) ⬆️
... and 1 more


@Mec-iS (Collaborator) commented Aug 29, 2022

thanks @AlanRace

It is probably better to adhere to the DenseMatrix format, so it would be nice for the method to return the transposed values, or a DenseMatrix directly.

@AlanRace (Author)

Maybe I am misunderstanding, but predict_probs does return a DenseMatrix.

It looks like DenseMatrix::from_vec (which is called from DenseMatrix::from_array as part of your test) assumes that the given vector is in row-major form, but the entered values in the test are in column-major form.

Would you prefer that the matrix returned from predict_probs is num classes x num observations, rather than the current num observations x num classes?
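(For readers following the row-major vs column-major discussion, here is a minimal sketch in plain Rust, independent of smartcore's internals: a flat buffer interpreted as row-major stores element (row, col) at index row * ncols + col, so values written out column by column are read back as the transposed matrix unless the dimensions are swapped and the result transposed.)

// Interpret a flat buffer as a row-major matrix: element (row, col)
// lives at index row * ncols + col.
fn get_row_major(data: &[f64], ncols: usize, row: usize, col: usize) -> f64 {
    data[row * ncols + col]
}

fn main() {
    // Intended 2x3 matrix: [[1, 2, 3], [4, 5, 6]]
    let row_major = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; // values listed row by row
    let col_major = [1.0, 4.0, 2.0, 5.0, 3.0, 6.0]; // same values listed column by column
    assert_eq!(get_row_major(&row_major, 3, 1, 0), 4.0); // reads the intended element
    assert_eq!(get_row_major(&col_major, 3, 1, 0), 5.0); // column-major data is misread
}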

@Mec-iS (Collaborator) commented Aug 30, 2022

Yeah, the shape returned by from_vec and from_array is probably handier. Thanks again.

@Mec-iS (Collaborator) commented Aug 30, 2022

@morenol @VolodymyrOrlov could you please take a look at the failing WASM test? It looks like we get different results for different targets. Rounding seems to work differently on WASM; the results look close, but not close enough.
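(One common way to make such a test robust across targets is to compare the probabilities within a tolerance rather than exactly. A minimal sketch follows; the helper and the epsilon value are illustrative, not an existing project convention.)

// Compare expected and actual values element-wise within a tolerance,
// instead of requiring bit-for-bit equality across targets.
fn assert_close(expected: &[f64], actual: &[f64], eps: f64) {
    assert_eq!(expected.len(), actual.len());
    for (e, a) in expected.iter().zip(actual.iter()) {
        assert!((e - a).abs() <= eps, "expected {e}, got {a}");
    }
}

// Usage in a test: assert_close(&expected_probs, &actual_probs, 1e-6);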

@@ -553,6 +554,37 @@ impl<T: RealNumber> RandomForestClassifier<T> {
which_max(&result)
}

/// Predict the per-class probabilities for each observation.
/// The probability is calculated as the fraction of trees that predicted a given class.
pub fn predict_probs<M: Matrix<T>>(&self, x: &M) -> Result<DenseMatrix<f64>, Failed> {
Review comment from a collaborator:

In scikit it is called predict_proba; I think it is better to keep the same name.

20,
2,
&[
1.0, 0.0, 0.78, 0.22, 0.95, 0.05, 0.82, 0.18, 1.0, 0.0, 0.92, 0.08, 0.99, 0.01,
Review comment from a collaborator:

Are these the expected values?

Reply from a collaborator:

Yes, those are the results as returned by the test. They match across all targets except WASM.

@morenol (Collaborator) commented Sep 24, 2022

They are failing for me locally, and they also failed in CI for x86_64-unknown-linux-gnu.

I think the green checks in CI come only from the jobs that build the crate but do not run tests (32-bit archs).

@alexis2804 commented Oct 3, 2022

Hello, when will this function be available? I really need it in order to perform model ensembling!

Thanks a lot

@Mec-iS (Collaborator) commented Oct 3, 2022

@alexis2804 Unfortunately we have problems with some tests; you can take a look at them by fetching this branch.

@Mec-iS (Collaborator) commented Oct 31, 2022

Moved to #211 to solve conflicts.

@Mec-iS closed this Oct 31, 2022