[WIP] Classification metrics overhaul 1/3 [wip] #4835
Conversation
Hello @tadejsv! Thanks for updating this PR.
Comment last updated at 2020-11-24 15:07:16 UTC
@tadejsv Great work!
I would ask you to split this PR into a few smaller ones, like an update/addition for each metric as a separate PR, so it will be much easier/smoother for us to preview and merge it sooner 🐰
cc: @SkafteNicki @justusschock @teddykoker
I have taken a quick look over your code and have just some first comments:
Lastly, I will not be the one lecturing here, because I am also really bad at making too-big PRs, but I don't think this will have a chance at getting merged unless it gets divided into more manageable pieces.
Thanks guys, all fair points. I'll see what I can do about breaking this up into smaller pieces - the thing is that there's a lot of inter-dependence (everything depends on the new formatting function, recall et al. depend on stat scores), so I guess it would only be more manageable if the PRs were merged sequentially.
@tadejsv I want to stress that I don't want your work to be in vain. You have done great work, and I think a lot of the changes can be made part of Lightning :]
@rohitgr7 yep, I saw your PR before :) Thanks for mentioning the idea about using it as a param, I think that would actually be the best way - I can easily add it to my own metric. It'll be done later today (in a split-off PR)
Sure
On Tue, Nov 24, 2020 at 10:15 PM Teddy Koker wrote:

This PR has been split off into multiple smaller PRs, starting with #4837. Can I close this one @tadejsv?
This PR overhauls the classification metrics package, providing uniform input type classification and transformation into a common format. I apply this to all existing class-based classification metrics, create a few new ones on top of it, and unify all "binarized" classification metrics (varying-threshold metrics like AUROC are coming up next). I also improve the metrics documentation page.
I realize this PR is way too large - I realized that when it was already too late for me to turn back 🙃. To make the job of the reviewers easier (since, due to the extensive changes, git diffs are not so useful), I'll give an exhaustive list of all changes + my reasoning for them below.
General (fundamental) changes
I have created a new `_input_format_classification` function (in `metrics/classification/utils`). The job of this function is to a) validate, and b) transform the inputs into a common format. This common format is a binary label indicator array: either `(N, C)`, or `(N, C, X)` (the latter only for multi-dimensional multi-class inputs).

I believe that having such a "central" function is crucial, as it gets rid of code duplication (which was present in PL metrics before), and enables metric developers to focus on developing the metrics themselves, and not on standardizing and validating inputs.
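To make the "common format" idea concrete, here is a hedged pure-Python sketch (plain lists stand in for tensors, and all names are illustrative only - this is not the actual `_input_format_classification` code). It maps a few different input types to the same `(N, C)` binary label indicator layout:

```python
# Illustrative sketch: convert several prediction formats into one
# (N, C) binary label indicator layout. Multi-label and (N, C, X)
# handling are omitted for brevity.

def to_label_indicator(preds, num_classes, threshold=0.5):
    """Convert predictions to an (N, C) binary indicator list-of-lists.

    Accepts class indices (ints), binary probabilities (floats), or
    per-class probability rows (lists of C floats).
    """
    out = []
    for p in preds:
        row = [0] * num_classes
        if isinstance(p, int):       # multi-class label: one-hot encode
            row[p] = 1
        elif isinstance(p, float):   # binary probability: threshold it
            row[int(p >= threshold)] = 1
        else:                        # per-class probabilities: take argmax
            row[max(range(num_classes), key=lambda i: p[i])] = 1
        out.append(row)
    return out

# All three input types end up in the same (N, C) format:
print(to_label_indicator([2, 0], num_classes=3))          # [[0, 0, 1], [1, 0, 0]]
print(to_label_indicator([0.3, 0.9], num_classes=2))      # [[1, 0], [0, 1]]
print(to_label_indicator([[0.1, 0.7, 0.2]], num_classes=3))  # [[0, 1, 0]]
```

Once every metric receives inputs in this shape, the per-metric `update` logic reduces to arithmetic on indicator arrays.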
The validation performed on the inputs basically makes sure that they fall into one of the possible input type cases, and that the values are consistent with both the type of the inputs and the additional parameters set (e.g. that there is no label higher than `num_classes` in the target). The docstrings (and the new "Input types" section in the documentation) give all the details about how the standardization and validation are performed.

Here I'll list the parameters of this function (many of which are also present on some metrics), and why I decided to use them:
- `threshold`: The probability threshold for binarizing binary and multi-label inputs.
- `num_classes`: The number of classes. Used either to decide the `C` dimension of the inputs, or, if this is already implicitly given, to ensure consistency between the inputs and the number of classes the user specified when creating the metric (thus avoiding either having to check this manually in `update` for each metric, or raising an error when updating the state, which may not be very clear to the user).
- `top_k`: For (multi-dimensional) multi-class inputs, if predictions are given as probabilities, selects the top k highest probabilities per sample. It's a generalization of the usual procedure, which has `k=1`. Currently, only the TopKAccuracy metric uses this, but I believe there is more potential to it.
- `is_multiclass`: Used for transforming binary or multi-label inputs to 2-class multi-class and 2-class multi-dimensional multi-class inputs, respectively. And vice versa.

Why? This is similar to the `multilabel` argument that was (is?) present on some metrics. I believe this is a better name for it, as it also deals with transforming to/from binary. But why is it needed? There are cases where it is not clear what the inputs are: for example, say that both preds and target are of the form `[0, 1, 0, 1]`. This actually appears to be multi-class (it could simply have happened that only 0s and 1s appeared in this batch), so an explicit instruction is needed to tell the metrics that this is in fact binary. On the other hand, sometimes we would like to treat binary inputs as two-class inputs - this is the case in the confusion matrix.
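To make the ambiguity concrete, here is a tiny pure-Python sketch (illustrative names only, not the actual PL code) of how the exact same labels get standardized differently depending on an explicit flag:

```python
# preds == target == [0, 1, 0, 1] can be read two ways, so an explicit
# flag is needed to pick the interpretation. Illustrative sketch only.

def interpret(labels, treat_as_multiclass):
    if treat_as_multiclass:
        # multi-class with classes {0, 1}: one-hot encode into two columns
        return [[1 - l, l] for l in labels]
    # binary: keep a single 0/1 column
    return [[l] for l in labels]

labels = [0, 1, 0, 1]
print(interpret(labels, treat_as_multiclass=True))   # [[1, 0], [0, 1], [1, 0], [0, 1]]
print(interpret(labels, treat_as_multiclass=False))  # [[0], [1], [0], [1]]
```

The data alone cannot distinguish the two cases, which is why inferring this from `num_classes` alone (discussed next) cannot fully resolve it.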
I also experimented with using `num_classes` to determine this. Besides this being a very confusing approach, requiring several paragraphs to explain clearly, it also does not resolve all ambiguities (is setting `num_classes=1` with 2-class probability predictions a request to treat the data as binary, or an inconsistency of the inputs that should raise an error?). So I think `is_multiclass` is the best approach here.

Metrics (changes and additions)
Note that all metrics listed below now have a unified functional and class interface.
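As a rough illustration of what a unified class and functional interface means in practice, here is a hedged, highly simplified sketch (hypothetical names and logic, not the actual PL implementations): the functional form is a pure one-shot function, while the class form accumulates state across batches.

```python
# Simplified sketch of the class/functional interface pattern.
# Not the real PL metrics code - just the shape of the API.

def accuracy_fn(preds, target):
    """Functional form: one-shot computation on a single batch."""
    correct = sum(p == t for p, t in zip(preds, target))
    return correct / len(target)

class Accuracy:
    """Class form: update() per batch, compute() over everything seen."""
    def __init__(self):
        self.correct = 0
        self.total = 0

    def update(self, preds, target):
        self.correct += sum(p == t for p, t in zip(preds, target))
        self.total += len(target)

    def compute(self):
        return self.correct / self.total

metric = Accuracy()
metric.update([0, 1, 1], [0, 1, 0])   # batch 1: 2/3 correct
metric.update([1], [1])               # batch 2: 1/1 correct
print(metric.compute())               # 0.75
print(accuracy_fn([0, 1, 1, 1], [0, 1, 0, 1]))  # 0.75
```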
Accuracy
The behavior of the metric for multi-label and multi-dimensional multi-class inputs is now different: it calculates subset accuracy instead of `1 - hamming_loss` (as was the case before). Additionally, a `top_k` parameter enables it to be used as top-k accuracy for multi-class inputs (with probabilities) - thanks @rohitgr7. The `top_k` format does not yet have an equivalent in sklearn (it will in 0.24).

Why? First, this brings the behavior of the metric in line with the `accuracy_score` metric from sklearn. Second, I think this is a much more natural extension of accuracy from multi-class to multi-label data. Multi-label data is, in my experience, often very "sparse" - with very few 1s per row, just like multi-class. So if correctly predicting all those zeros does not count separately for multi-class, neither should it for multi-label inputs.

HammingLoss (new)
This is equivalent to `hamming_loss` from sklearn. `1 - hamming_loss` is a kind of accuracy, and gives the behavior of the old accuracy metric on multi-label and multi-dimensional multi-class inputs.

StatScores (new)
Computes stat score, i.e. true positives, false positives, true negatives, false negatives. It is used as a base for many other metrics (recall, precision, fbeta, iou). It is made to work with all types of inputs, and is very configurable. There are two main parameters here:
- `reduce`: This determines how the statistics should be counted: globally (summing across all labels), by classes, or by samples. The possible values (`micro`, `macro`, `samples`) correspond to the averaging names for metrics such as precision. This is "inspired" by sklearn's averaging argument in such metrics.
- `mdmc_reduce`: In the case of multi-dimensional multi-class (mdmc) inputs, how should the statistics be reduced? This is on top of the `reduce` argument. The possible values are `global` (i.e. the extra dimensions are actually sample dimensions) and `samplewise` (compute the statistics for each sample, taking the extra dimensions as a sample-within-sample dimension).

Why? The reason for these two options (right now PL metrics implements the `global` option by default) is that for some "downstream" metrics, such as iou, it is, in my opinion, much more natural to compute the metric per sample, and then average across samples, rather than join everything into one "blob" and compute the averages for this blob. For example, if you are doing image segmentation, it makes more sense to compute the metrics per image, as the model is trained on images, not blobs :) Also, aggregating everything may disguise some unwanted behavior (such as an inability to predict a minority class), which would be evident if the averaging were done per sample (`samplewise`).

Also, this class metric (and its functional equivalent) now returns the stat scores concatenated in a single tensor, instead of returning a tuple. I did this because the standard metrics testing framework in PL does not support non-tensor returns - and the change should be minor for users.
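As a rough sketch of the counting logic (hypothetical pure-Python code, not the actual StatScores implementation; the `samples` and mdmc reductions are omitted for brevity), the per-class counts and the `micro`/`macro` behavior of `reduce` look roughly like this:

```python
# Illustrative stat-scores sketch for (N,)-shaped integer preds/targets.
# reduce="macro" keeps one (tp, fp, tn, fn) tuple per class;
# reduce="micro" sums the counts over all classes.

def stat_scores(preds, target, num_classes, reduce="micro"):
    per_class = []
    for c in range(num_classes):
        tp = sum(1 for p, t in zip(preds, target) if p == c and t == c)
        fp = sum(1 for p, t in zip(preds, target) if p == c and t != c)
        fn = sum(1 for p, t in zip(preds, target) if p != c and t == c)
        tn = len(preds) - tp - fp - fn
        per_class.append((tp, fp, tn, fn))
    if reduce == "macro":
        return per_class
    # micro: sum the raw counts across classes
    return tuple(sum(s[i] for s in per_class) for i in range(4))

preds  = [0, 1, 2, 2]
target = [0, 1, 1, 2]
print(stat_scores(preds, target, num_classes=3, reduce="micro"))  # (3, 1, 7, 1)
print(stat_scores(preds, target, num_classes=3, reduce="macro"))
```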
Recall, Precision, FBeta, F1, IoU (new)
These are all metrics that can be represented as a (quotient) function of "stat scores" - thanks to subclassing `StatScores`, their code is extremely simple. Here are the parameters common to all of them:

- `average`: This builds on the `reduce` parameter in StatScores. The options here (`micro`, `macro`, `weighted`, `none` or None, `samples`) are exactly equivalent to their sklearn counterparts, so I won't go into details.
- `mdmc_average`: This builds on `mdmc_reduce` from StatScores. It decides how to average the scores for multi-dimensional multi-class inputs. Already discussed under `mdmc_reduce`.
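To illustrate the "quotient of stat scores" idea, here is a hedged pure-Python sketch (illustrative only, not the actual implementation) computing precision and recall from per-class `(tp, fp, tn, fn)` tuples, with `micro` and `macro` averaging:

```python
# Precision = tp / (tp + fp), recall = tp / (tp + fn), built on top of
# per-class stat-score tuples produced by a StatScores-like step.
# Illustrative sketch; weighted/samples averaging omitted for brevity.

def precision_recall(per_class_stats, average="macro"):
    if average == "micro":
        # sum the raw counts first, then take the quotients once
        tp = sum(s[0] for s in per_class_stats)
        fp = sum(s[1] for s in per_class_stats)
        fn = sum(s[3] for s in per_class_stats)
        return tp / (tp + fp), tp / (tp + fn)
    # macro: quotient per class, then the unweighted mean
    precs = [tp / (tp + fp) if tp + fp else 0.0 for tp, fp, tn, fn in per_class_stats]
    recs  = [tp / (tp + fn) if tp + fn else 0.0 for tp, fp, tn, fn in per_class_stats]
    return sum(precs) / len(precs), sum(recs) / len(recs)

stats = [(1, 0, 3, 0), (1, 0, 2, 1), (1, 1, 2, 0)]  # per-class (tp, fp, tn, fn)
print(precision_recall(stats, average="micro"))  # (0.75, 0.75)
print(precision_recall(stats, average="macro"))
```

Because all the counting lives in the StatScores step, each derived metric reduces to a few lines of quotient arithmetic like this.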
Confusion matrix
Here I only adjusted the internals a bit, to use the new input standardization function. As a result, the function gained a `logits` parameter. I also changed the docstring to make it clear what inputs are expected and how they are treated.

Documentation
The main change on the metrics documentation page is that instead of dumping all class/function definitions there with `autoclass`/`autofunction`, it now only contains a list (built with `autosummary`) of short descriptions, which link to the API page of each class and function separately - just like on the logging documentation page, for example. Additionally, instead of metrics being organized into "Class Metrics" and "Functional Metrics", they are now organized by topic (Classification, Regression, ...), and within topics split into class and functional, where necessary.

Why? The first change is obvious - the way the page is set up now, it's very hard for users to get an overview of what metrics are available, as they need to click through to the definition of each one to survey them. Having all the metrics in one table with short descriptions makes this task much easier. Additionally, now that all classes/functions have their own API pages, it is possible to link to them easily with `:class:` - something I use in some metrics.

Other issues
I've ofc added tests for every piece of code I touched: metrics should have close to 100% coverage now (except for the legacy functional code). However, this also brings a problem: the tests now take a long time - a bit over 7 minutes for the metrics package. Most of it is due to precision, recall, fbeta and iou, as there are a lot of parameter combinations to go over there (I already skip some), and testing ddp seems to take particularly long. If anyone has any ideas on how to get around this, I'm all ears :)
Also, for some reason the ddp tests did not work for me locally - this has nothing to do with my changes, it happens on the current master branch too. The tests would just hang. I got "around" this by temporarily setting `NUM_PROCESSES=1` while testing - I believe this should still guarantee that the ddp stuff works as it should, while enabling the tests to actually run for me locally. We'll see what happens in CI. If anyone has any ideas here, they're welcome.

Why 1/3?
There will be two more parts: the second will be purely incremental, just adding two confusion-matrix based metrics (Cohen's kappa and the Matthews correlation coefficient). This will not involve any changes to existing code.
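As a preview of what a confusion-matrix based metric looks like, here is a hedged sketch (illustrative only, not the planned implementation) of the binary Matthews correlation coefficient computed directly from confusion-matrix counts:

```python
import math

# Binary MCC from confusion-matrix counts:
# (tp*tn - fp*fn) / sqrt((tp+fp)(tp+fn)(tn+fp)(tn+fn)).
# Illustrative sketch only; the multi-class case works analogously
# from the full confusion matrix.

def binary_mcc(tp, fp, tn, fn):
    num = tp * tn - fp * fn
    den = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    return num / den if den else 0.0

# Perfect predictions give MCC = 1, chance-level counts give 0:
print(binary_mcc(tp=5, fp=0, tn=5, fn=0))  # 1.0
print(binary_mcc(tp=2, fp=2, tn=2, fn=2))  # 0.0
```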
The third part, however, will be more similar to the first one. It will involve tackling all remaining "varying threshold" metrics - such as precision recall curve, AUROC, etc.
But first let's get this one merged!