
[WIP] Classification metrics overhaul 1/3 [wip] #4835

Closed
wants to merge 41 commits into from

Conversation

@tadejsv (Contributor) commented Nov 24, 2020

This PR overhauls the classification metrics package, providing uniform input type classification and transformation into a common format. I apply this to all existing classification metrics, create a few new ones on top of it, and unify all "binarized" classification metrics (varying-threshold metrics like AUROC are coming up next). I also improve the metrics documentation page.

I realize this PR is way too large. I realized that when it was already too late for me to turn back 🙃. To make the job of the reviewers easier (since, due to the extensive changes, git diffs are not so useful), I'll give an exhaustive list of all changes + my reasoning for them below.

General (fundamental) changes

I have created a new _input_format_classification function (in metrics/classification/utils). The job of this function is to a) validate, and b) transform the inputs into a common format. This common format is a binary label indicator array: either (N, C), or (N, C, X) (only for multi-dimensional multi-class inputs).

I believe that having such a "central" function is crucial, as it gets rid of code duplication (which was present in PL metrics before), and enables metric developers to focus on developing the metrics themselves, and not on standardizing and validating inputs.

The validation ensures that the inputs fall into one of the possible input type cases, and that the values are consistent with both the input type and the additional parameters set (e.g. that no label in target is higher than num_classes). The docstrings (and the new "Input types" section in the documentation) give all the details about how the standardization and validation are performed.
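For intuition, here is a minimal sketch (plain PyTorch, not the actual helper) of what landing in the common format looks like for multi-class labels given as class indices:

```python
import torch
import torch.nn.functional as F

# Multi-class target for N=4 samples and C=3 classes, given as class indices.
target = torch.tensor([0, 2, 1, 2])

# The common format is a binary label indicator tensor of shape (N, C),
# which for class indices is just the one-hot encoding.
target_indicator = F.one_hot(target, num_classes=3)
# tensor([[1, 0, 0],
#         [0, 0, 1],
#         [0, 1, 0],
#         [0, 0, 1]])
```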

Here I'll list the parameters of this function (many of which are also present on some metrics), and why I decided to use them (a short usage sketch follows the list):

  • threshold: The probability threshold for binarizing binary and multi-label inputs.

  • num_classes: the number of classes. Used either to decide the C dimension of the inputs or, if this is already implicitly given, to ensure consistency between the inputs and the number of classes the user specified when creating the metric (thus avoiding either having to check this manually in update for each metric, or raising an error when updating the state, which may not be very clear to the user).

  • top_k: for (multi-dimensional) multi-class inputs, if predictions are given as probabilities, selects the top k highest probabilities per sample. This is a generalization of the usual procedure, which corresponds to k=1. Currently only the TopKAccuracy metric uses this, but I believe it has more potential.

  • is_multiclass: used for transforming binary or multi-label inputs to 2-class multi-class and 2-class multi-dimensional multi-class inputs, respectively, and vice versa.

    Why? This is similar to the multilabel argument that was (is?) present on some metrics. I believe is_multiclass is a better name, as the parameter also deals with transforming to/from binary inputs. But why is it needed? There are cases where it is not clear what the inputs are: for example, say that both preds and target are of the form [0,1,0,1]. This looks like multi-class data (it could simply be that this batch happened to contain only 0s and 1s), so an explicit instruction is needed to tell the metric that the data is in fact binary. On the other hand, sometimes we would like to treat binary inputs as two-class inputs - this is the case in the confusion matrix.

    I also experimented with using num_classes to determine this. Besides being a very confusing approach, requiring several paragraphs to explain clearly, it also does not resolve all ambiguities (is setting num_classes=1 with 2-class probability predictions a request to treat the data as binary, or an inconsistency of the inputs that should raise an error?). So I think is_multiclass is the best approach here.
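To make the roles of threshold and top_k concrete, here is a hypothetical, heavily simplified sketch of such a normalization step (this is not the PR's _input_format_classification, just an illustration of the idea):

```python
import torch

def to_label_indicator(preds: torch.Tensor, threshold: float = 0.5, top_k: int = 1) -> torch.Tensor:
    """Hypothetical sketch: turn probability predictions into an (N, C) binary indicator.

    Not the PR's _input_format_classification - only illustrates what `threshold`
    and `top_k` do conceptually.
    """
    if preds.ndim == 1:
        # Binary probabilities of shape (N,): threshold, then view as 2-class indicators
        # (roughly what is_multiclass=True would ask for).
        binary = (preds >= threshold).long()
        return torch.stack([1 - binary, binary], dim=1)

    # Multi-class probabilities of shape (N, C): keep the k most probable classes per sample.
    _, top_indices = preds.topk(top_k, dim=1)
    indicator = torch.zeros_like(preds, dtype=torch.long)
    return indicator.scatter(1, top_indices, 1)

probs = torch.tensor([[0.1, 0.6, 0.3],
                      [0.5, 0.2, 0.3]])
print(to_label_indicator(probs, top_k=2))
# tensor([[0, 1, 1],
#         [1, 0, 1]])
```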

Metrics (changes and additions)

Note that all metrics listed below now have a unified functional and class interface.

Accuracy

The behavior of the metric in case of multi-label and multi-dimensional multi-class inputs is now different: it calculates subset accuracy instead of 1 - hamming_loss (as was the case before). Additionally, a top_k parameter enables it to be used as top-k accuracy for multi-class inputs (with probabilities) - thanks @rohitgr7. The top_k format does not yet have an equivalent in sklearn (it will in 0.24).

Why? First, this brings the behavior of the metric in line with the accuracy_score metric from sklearn. Second, I think this is a much more natural extension of accuracy from multi-class to multi-label data. Multi-label data is, in my experience, often very "sparse" - with very few 1s per row, just like multi-class data. So if correctly predicting all those zeros does not count separately for multi-class inputs, neither should it for multi-label inputs.

HammingLoss (new)

This is equivalent to hamming_loss from sklearn. 1 - hamming_loss is a kind of accuracy, and gives the behavior of the old Accuracy metric on multi-label and multi-dimensional multi-class inputs.
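To make the difference concrete, here is a small sketch with sklearn (on multi-label indicator inputs) contrasting subset accuracy - the new Accuracy behavior - with 1 - hamming_loss, the old behavior that HammingLoss preserves:

```python
import numpy as np
from sklearn.metrics import accuracy_score, hamming_loss

# Multi-label indicators for 3 samples and 4 labels.
target = np.array([[1, 0, 0, 1],
                   [0, 1, 0, 0],
                   [0, 0, 1, 0]])
preds = np.array([[1, 0, 0, 1],
                  [0, 1, 0, 1],   # one extra label predicted for sample 2
                  [0, 0, 1, 0]])

print(accuracy_score(target, preds))    # subset accuracy: 2/3 samples exactly right ≈ 0.667
print(1 - hamming_loss(target, preds))  # 11/12 individual labels right ≈ 0.917
```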

StatScores (new)

Computes the stat scores, i.e. true positives, false positives, true negatives and false negatives. It is used as a base for many other metrics (recall, precision, fbeta, iou). It is made to work with all types of inputs and is very configurable. There are two main parameters here (a short sketch follows below):

  • reduce: Determines how the statistics should be counted: globally (summing across all labels), by classes, or by samples. The possible values (micro, macro, samples) correspond to the averaging names for metrics such as precision. This is "inspired" by sklearn's average argument in such metrics.

  • mdmc_reduce: In the case of multi-dimensional multi-class (mdmc) inputs, determines how the statistics should be reduced. This is on top of the reduce argument. The possible values are global (i.e. the extra dimensions are treated as additional sample dimensions) and samplewise (compute the statistics for each sample, treating the extra dimensions as a sample-within-sample dimension).

    Why? The reason for these two options (right now PL metrics implements the global option by default) is that for some "downstream" metrics, such as iou, it is, in my opinion, much more natural to compute the metric per sample and then average across samples, rather than join everything into one "blob" and compute the averages for that blob. For example, if you are doing image segmentation, it makes more sense to compute the metrics per image, as the model is trained on images, not blobs :) Also, aggregating everything may disguise some unwanted behavior (such as the inability to predict a minority class), which would be evident if averaging was done per sample (samplewise).

Also, this class metric (and the functional equivalent) now returns the stat scores concatenated in a single tensor, instead of returning a tuple. I did this because the standard metrics testing framework in PL does not support non-tensor returns - and the change should be minor for users.
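As an illustration of what reduce controls (a hypothetical sketch, not the PR's StatScores implementation), counting tp/fp/tn/fn on inputs already in the common (N, C) format might look like this:

```python
import torch

def stat_scores_sketch(preds: torch.Tensor, target: torch.Tensor, reduce: str = "micro") -> torch.Tensor:
    """Hypothetical sketch, not the PR's StatScores: count tp/fp/tn/fn on (N, C) indicators.

    reduce="micro" sums over all samples and classes, reduce="macro" keeps one count per class.
    """
    dim = [0, 1] if reduce == "micro" else 0
    tp = ((preds == 1) & (target == 1)).sum(dim=dim)
    fp = ((preds == 1) & (target == 0)).sum(dim=dim)
    tn = ((preds == 0) & (target == 0)).sum(dim=dim)
    fn = ((preds == 0) & (target == 1)).sum(dim=dim)
    # As in the PR, return everything concatenated in a single tensor instead of a tuple.
    return torch.stack([tp, fp, tn, fn], dim=-1)

preds = torch.tensor([[1, 0], [1, 1], [0, 1]])
target = torch.tensor([[1, 0], [0, 1], [0, 1]])
print(stat_scores_sketch(preds, target, reduce="macro"))
# tensor([[1, 1, 1, 0],
#         [2, 0, 1, 0]])
```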

Recall, Precision, FBeta, F1, IoU (new)

These are all metrics that can be expressed as a (quotient) function of the "stat scores" - thanks to subclassing StatScores, their code is extremely simple. Here are the parameters common to all of them (a short sketch follows the list):

  • average: this builds on the reduce parameter in StatScores. The options here (micro, macro, weighted, none or None, samples) are exactly equivalent to the sklearn counterparts, so I won't go into details.

  • mdmc_average: builds on the mdmc_reduce from StatScores. This decides how to average scores for multi-dimensional multi-class inputs. Already discussed in mdmc_reduce.
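As a rough sketch of why micro and macro averaging can diverge (hypothetical numbers, not the PR's code): precision is just a quotient of stat scores, and the averaging choice decides whether the quotient is taken before or after summing over classes.

```python
import torch

def precision_from_stats(tp: torch.Tensor, fp: torch.Tensor, average: str = "micro") -> torch.Tensor:
    """Hypothetical sketch, not the PR's Precision: precision as a quotient of stat scores."""
    if average == "micro":
        # Sum the counts over classes first, then take one global quotient.
        return tp.sum() / (tp.sum() + fp.sum())
    # "macro": per-class precision first, then an unweighted mean over classes.
    return (tp / (tp + fp)).mean()

# Per-class true/false positives for C=3 classes (e.g. from a StatScores-like computation).
tp = torch.tensor([8., 1., 1.])
fp = torch.tensor([2., 0., 8.])
print(precision_from_stats(tp, fp, average="micro"))  # 10 / 20 = 0.50
print(precision_from_stats(tp, fp, average="macro"))  # mean(0.80, 1.00, 0.11) ≈ 0.64
```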

Confusion matrix

Here I only adjusted the internals a bit to use the new input standardization function. As a result, the function has gained the logits parameter. I also changed the docstring to make it clear what inputs are expected and how they are treated.

Documentation

The main change to the metrics documentation page is that, instead of dumping all class/function definitions there with autoclass/autofunction, it now only contains a list (built with autosummary) of short descriptions, which link to the API page for each class and function separately - just like the logging documentation page, for example.

Additionally, instead of metrics being organized into "Class Metrics" and "Functional Metrics", they are now organized by topic (Classification, Regression, ...), and within topics split into class and functional, where necessary.

Why? The first change is obvious - the way the page is set up now, it's very hard for users to get an overview of what metrics are available, as they need to click on each definition to survey them. Having all the metrics in one table with a short description makes this task much easier. Additionally, now that all classes/functions have their own API pages, it is possible to link to them easily with :class: - something I use in some metrics.

Other issues

I've of course added tests for every piece of code I touched: metrics should have close to 100% coverage now (except for the legacy functional code). However, this also brings a problem: the tests now take a long time - a bit over 7 minutes for the metrics package. Most of it is due to precision, recall, fbeta and iou, as there are a lot of parameter combinations to go over there (I already skip some), and testing ddp seems to take particularly long. If anyone has any ideas on how to get around this, I'm all ears :)

Also, for some reason the ddp tests did not work for me locally - this has nothing to do with my changes, as it happens on the current master branch too; the tests just hang. I got "around" this by temporarily setting NUM_PROCESSES=1 while testing - I believe this should still guarantee that the ddp stuff works as it should, while enabling the tests to actually run for me locally. We'll see what happens in CI. If anyone has any ideas here, they are welcome.

Why 1/3?

There will be two more parts: the second one will be purely incremental, just adding two confusion-matrix based metrics (Cohen's kappa and the Matthews correlation coefficient). This will not involve any changes to existing code.

The third part, however, will be more similar to the first one: it will involve tackling all remaining "varying threshold" metrics - such as the precision-recall curve, AUROC, etc.

But first let's get this one merged!

@pep8speaks commented Nov 24, 2020

Hello @tadejsv! Thanks for updating this PR.

Line 23:63: E203 whitespace before ':'

Line 44:1: E731 do not assign a lambda expression, use a def
Line 44:1: E741 ambiguous variable name 'I'
Line 45:1: E731 do not assign a lambda expression, use a def
Line 46:1: E731 do not assign a lambda expression, use a def
Line 47:1: E731 do not assign a lambda expression, use a def
Line 48:1: E731 do not assign a lambda expression, use a def
Line 49:1: E731 do not assign a lambda expression, use a def
Line 50:1: E731 do not assign a lambda expression, use a def
Line 51:1: E731 do not assign a lambda expression, use a def
Line 52:1: E731 do not assign a lambda expression, use a def
Line 53:1: E731 do not assign a lambda expression, use a def
Line 54:1: E731 do not assign a lambda expression, use a def
Line 57:1: E731 do not assign a lambda expression, use a def
Line 58:1: E731 do not assign a lambda expression, use a def
Line 59:1: E731 do not assign a lambda expression, use a def
Line 60:1: E731 do not assign a lambda expression, use a def
Line 61:1: E731 do not assign a lambda expression, use a def
Line 62:1: E731 do not assign a lambda expression, use a def
Line 63:1: E731 do not assign a lambda expression, use a def
Line 64:1: E731 do not assign a lambda expression, use a def
Line 65:1: E731 do not assign a lambda expression, use a def
Line 66:1: E731 do not assign a lambda expression, use a def

Line 208:14: E203 whitespace before ':'

Comment last updated at 2020-11-24 15:07:16 UTC

@tadejsv tadejsv changed the title Cls metrics overhaul pt1 Classification metrics overhaul 1/3 Nov 24, 2020
@Borda Borda added feature Is an improvement or enhancement Important labels Nov 24, 2020
@Borda Borda added this to the 1.1 milestone Nov 24, 2020
@Borda (Member) left a comment

@tadejsv Great work!
I would ask you to split this PR into a few smaller ones, like an update for each metric/addition as a separate PR, so it will be much easier/smoother for us to preview and merge it sooner 🐰
cc: @SkafteNicki @justusschock @teddykoker

@SkafteNicki (Member) commented:

I have taken a quick look over your code and have some first comments:

  • Good: always good with more metrics! The tests also seem very well done :]
  • Up for discussion: I am not sure all the new arguments you add are necessary. For example, the logits argument should not be necessary if we state that we expect the input to be either labels or probabilities. The metrics API does not have to take care of every little detail IMO.
  • Should be discarded: the autosummary change you have made to the docs will not be merged. I can say this with certainty because I had a PR earlier this week with the exact same changes (Metrics as autosummary tables #4774) that we did not merge, because we want to keep the sidebar.

Lastly, I will not be the one lecturing here, because I am also really bad at making PRs that are too big, but I don't think this has a chance of getting merged unless it gets divided into more manageable pieces.

@tadejsv (Contributor, Author) commented Nov 24, 2020

Thanks guys, all fair points. I'll see what I can do about breaking this up into smaller pieces - the thing is that there's a lot of inter-dependence (everything depends on the new formatting function, recall et al. depend on stat scores), so I guess it would only be more manageable if the PRs were merged sequentially.

@SkafteNicki (Member) commented:

@tadejsv I want to stress that I don't want your work to be in vain. You have done great work, and I think a lot of the changes can be made part of Lightning :]
If you want help, I can try to help guide you on how to split this up.
Are you on slack?

@tadejsv tadejsv changed the title Classification metrics overhaul 1/3 [WIP] Classification metrics overhaul 1/3 Nov 24, 2020
@Borda Borda changed the title [WIP] Classification metrics overhaul 1/3 [WIP] Classification metrics overhaul 1/3 [wip] Nov 24, 2020
@rohitgr7 (Contributor) commented Nov 24, 2020

@tadejsv Great work!

BTW I am working on topk metrics in #3822, trying to add it as an argument rather than a separate metric, although I still need to resolve some issues there.

@Borda (Member) commented Nov 24, 2020

> BTW I am working on topk metrics in #3822, trying to add it as an argument rather than a separate metric, although I still need to resolve some issues there.

@rohitgr7 @tadejsv it would be great if you could sync on Slack so you don't duplicate each other's work

@tadejsv (Contributor, Author) commented Nov 24, 2020

@rohitgr7 yep, I saw your PR before :) Thanks for mentioning the idea of using it as a param - I think that would actually be the best way, and I can easily add it to my own metric. It'll be done later today (in a split-off PR).

@teddykoker (Contributor) commented:

This PR has been split off into multiple smaller PRs, starting with #4837. Can I close this one @tadejsv?

@tadejsv (Contributor, Author) commented Nov 24, 2020 via email

@Borda (Member) commented Nov 25, 2020

Just for the record, the successor PRs are:
#4837 #4838 #4839 #4842
Each new PR is blocked by its predecessor, and after merging it should be rebased to reveal the true change - so there is not much reason to read them all: the further they are in the chain, the more of the previous changes they accumulate...
@tadejsv thx again :]
