-
Notifications
You must be signed in to change notification settings - Fork 24.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[ML] Stratified cross validation split for classification #54087
[ML] Stratified cross validation split for classification #54087
Conversation
Pinging @elastic/ml-core (:ml) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great!
Two minor concerns :)
...elasticsearch/xpack/ml/dataframe/process/crossvalidation/CrossValidationSplitterFactory.java
Show resolved
Hide resolved
...sticsearch/xpack/ml/dataframe/process/crossvalidation/StratifiedCrossValidationSplitter.java
Outdated
Show resolved
Hide resolved
...elasticsearch/xpack/ml/dataframe/process/crossvalidation/CrossValidationSplitterFactory.java
Outdated
Show resolved
Hide resolved
...elasticsearch/xpack/ml/dataframe/process/crossvalidation/CrossValidationSplitterFactory.java
Show resolved
Hide resolved
...sticsearch/xpack/ml/dataframe/process/crossvalidation/StratifiedCrossValidationSplitter.java
Outdated
Show resolved
Hide resolved
...earch/xpack/ml/dataframe/process/crossvalidation/StratifiedCrossValidationSplitterTests.java
Outdated
Show resolved
Hide resolved
...earch/xpack/ml/dataframe/process/crossvalidation/StratifiedCrossValidationSplitterTests.java
Outdated
Show resolved
Hide resolved
...earch/xpack/ml/dataframe/process/crossvalidation/StratifiedCrossValidationSplitterTests.java
Show resolved
Hide resolved
As classification now works for multiple classes, randomly picking training/test data frame rows is not good enough. This commit introduces a stratified cross validation splitter that maintains the proportion of the each class in the dataset in the sample that is used for training the model.
37fe98a
to
9afc2b6
Compare
@@ -324,7 +324,8 @@ private void refreshIndices(String jobId) { | |||
); | |||
refreshRequest.indicesOptions(IndicesOptions.lenientExpandOpen()); | |||
|
|||
LOGGER.debug("[{}] Refreshing indices {}", jobId, Arrays.toString(refreshRequest.indices())); | |||
LOGGER.debug(() -> new ParameterizedMessage("[{}] Refreshing indices {}", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@benwtrent I forgot to address your comment before so I squeezed this one here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
...earch/xpack/ml/dataframe/process/crossvalidation/StratifiedCrossValidationSplitterTests.java
Show resolved
Hide resolved
… (#54104) As classification now works for multiple classes, randomly picking training/test data frame rows is not good enough. This commit introduces a stratified cross validation splitter that maintains the proportion of the each class in the dataset in the sample that is used for training the model. Backport of #54087
As classification now works for multiple classes, randomly
picking training/test data frame rows is not good enough.
This commit introduces a stratified cross validation splitter
that maintains the proportion of the each class in the dataset
in the sample that is used for training the model.