Added per-class probability prediction for random forests #138
@AlanRace thank you for your contribution to Smartcore! The change looks good, but you might want to look at |
I have added a test here but there is something wrong, please take a look: |
Add test to predict probabilities
@Mec-iS Thanks for supplying the test - I am guessing there was a problem due to row-major vs column-major storage of |
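To illustrate the row-major vs column-major guess above, here is a minimal sketch of how the same flat buffer yields different entries depending on the storage order. The helper names `get_row_major` and `get_col_major` are hypothetical and are not part of Smartcore; this makes no claim about which order `DenseMatrix` actually uses.

```rust
/// Hypothetical helper: element (row, col) when the flat buffer is read
/// in row-major order (rows stored one after another).
fn get_row_major(data: &[f64], ncols: usize, row: usize, col: usize) -> f64 {
    data[row * ncols + col]
}

/// Hypothetical helper: element (row, col) when the same flat buffer is read
/// in column-major order (columns stored one after another).
fn get_col_major(data: &[f64], nrows: usize, row: usize, col: usize) -> f64 {
    data[col * nrows + row]
}
```

Reading a 3x2 buffer such as `[1.0, 0.0, 0.78, 0.22, 0.95, 0.05]` with the wrong helper returns a different value for the same (row, col) pair, which is the kind of mismatch a test with hard-coded expected values would surface.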
Codecov Report
@@ Coverage Diff @@
## development #138 +/- ##
===============================================
+ Coverage 83.40% 84.01% +0.60%
===============================================
Files 78 81 +3
Lines 8377 8751 +374
===============================================
+ Hits 6987 7352 +365
- Misses 1390 1399 +9
thanks @AlanRace it is probably better to adhere to the DenseMatrix format, so it would be nice for the method to return the transposed values or directly a DenseMatrix. |
Maybe I am misunderstanding, but It looks like Would you prefer that the matrix returned from |
yeah, probably in the shape returned by |
@morenol @VolodymyrOrlov could you please take a look at the failing WASM test? It looks like we get different results for different targets. Rounding seems to work differently on WASM; the results look close but not close enough.
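For reference, a tolerance-based comparison of the kind alluded to above ("close but not close enough") might look like the sketch below. The helper `approx_eq_slices` and the choice of tolerance are assumptions for illustration only; this is not the comparison used by the test in this PR, and a looser tolerance would not help if the targets genuinely build different trees.

```rust
/// Hypothetical helper: true when both slices have the same length and every
/// pair of corresponding elements differs by at most `tol`.
fn approx_eq_slices(a: &[f64], b: &[f64], tol: f64) -> bool {
    a.len() == b.len()
        && a.iter()
            .zip(b.iter())
            .all(|(x, y)| (x - y).abs() <= tol)
}
```

A call such as `approx_eq_slices(&expected, &actual, 1e-6)` accepts small rounding differences between targets while still rejecting results that differ in the second decimal place.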
@@ -553,6 +554,37 @@ impl<T: RealNumber> RandomForestClassifier<T> {
        which_max(&result)
    }

    /// Predict the per-class probabilities for each observation.
    /// The probability is calculated as the fraction of trees that predicted a given class
    pub fn predict_probs<M: Matrix<T>>(&self, x: &M) -> Result<DenseMatrix<f64>, Failed> {
In scikit-learn it is called predict_proba; I think it is better to keep the same name.
    20,
    2,
    &[
        1.0, 0.0, 0.78, 0.22, 0.95, 0.05, 0.82, 0.18, 1.0, 0.0, 0.92, 0.08, 0.99, 0.01,
Are these the expected values?
Yes, those are the results as returned by the test. They match across all targets except Wasm.
They are failing for me locally, and they also failed in the CI for x86_64-unknown-linux-gnu.
I think the green checks in the CI are only on the jobs that build the crate but do not run the tests (32-bit arch).
Hello guys, when will this function be available? I really need it in order to perform model ensembling! Thanks a lot
@alexis2804 unfortunately we have problems with some tests, you can take a look at them by fetching this branch |
moved to #211 to solve conflicts |
Added a function to predict the per-class probabilities for each observation.
probabilities is a KxC matrix, where K is the number of observations and C is the number of classes. Probabilities are calculated as the fraction of trees in the random forest that predicted the given class.
Answer to #50 for random forests.
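The fraction-of-trees calculation described above can be sketched in plain Rust as follows. This is only an illustration under stated assumptions, not the code from this PR: `vote_fractions` is a hypothetical helper, the per-tree predictions are assumed to be available as class indices, every tree is assumed to have voted on the same number of observations, and nested vectors stand in for `DenseMatrix`.

```rust
/// Hypothetical helper: turn per-tree class votes into per-class probabilities.
/// `votes[t][i]` is the class index that tree `t` predicted for observation `i`.
/// Returns a K x C table (K observations, C classes) as nested vectors.
/// Assumes every inner vector has the same length and all class indices
/// are smaller than `n_classes`.
fn vote_fractions(votes: &[Vec<usize>], n_classes: usize) -> Vec<Vec<f64>> {
    let n_trees = votes.len();
    let n_obs = votes.first().map_or(0, |v| v.len());

    // Count, for each observation, how many trees voted for each class.
    let mut probs = vec![vec![0.0_f64; n_classes]; n_obs];
    for tree_votes in votes {
        for (i, &class) in tree_votes.iter().enumerate() {
            probs[i][class] += 1.0;
        }
    }

    // Divide the counts by the number of trees so each row sums to 1.
    for row in probs.iter_mut() {
        for p in row.iter_mut() {
            *p /= n_trees as f64;
        }
    }
    probs
}
```

With four trees voting `[0, 0, 1, 0]` on one observation, the corresponding row is 0.75 for class 0 and 0.25 for class 1.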