
Make optional argmax for y_pred in Confusion Matrix, Precision, Recall, Accuracy #822

Open
vfdev-5 opened this issue Mar 2, 2020 · 9 comments

Comments

vfdev-5 (Collaborator) commented Mar 2, 2020

🚀 Feature

Today, the conditions on the input of ConfusionMatrix (and of Precision, Recall, Accuracy in the multiclass case) are the following:

    - `y_pred` must contain logits and have the shape (batch_size, num_categories, ...)
    - `y` must have the shape (batch_size, ...) and contain ground-truth class indices,
        with or without the background class. During the computation, the argmax of `y_pred` is taken to determine the predicted classes.

Taking the argmax of y_pred should become optional, so that the winning class can be determined by some other rule. Let's keep argmax as the default behaviour when y_pred is (N, C, ...), and not apply it when y_pred.shape == y.shape, i.e. (N, ...).
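
A minimal sketch of the proposed rule in plain PyTorch (the helper name select_predicted_indices is illustrative, not ignite code):

import torch

def select_predicted_indices(y_pred, y):
    # proposed behaviour: apply argmax only when y_pred carries a class dimension
    if y_pred.ndimension() == y.ndimension() + 1:
        # y_pred: (N, C, ...) scores/logits -> take the winning class per sample
        return torch.argmax(y_pred, dim=1)
    if y_pred.shape == y.shape:
        # y_pred: (N, ...) already holds class indices -> leave it untouched
        return y_pred
    raise ValueError("y_pred must be (N, C, ...) or have the same shape as y")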

vfdev-5 changed the title from "Make optional argmax for y_pred in Confusion Matrix" to "Make optional argmax for y_pred in Confusion Matrix, Precision, Recall, Accuracy" on Mar 17, 2020
sdesrozis (Contributor) commented Mar 21, 2020

So:

  1. argmax should be optional and the user should be able to provide their own rule.
  2. if y_pred.shape == y.shape, i.e. (N, ...), do not apply argmax.

Is that OK?

vfdev-5 (Collaborator, Author) commented Mar 21, 2020

argmax should be optional and the user should be able to provide their own rule.

If y_pred has a C dimension, i.e. (N, C, ...), while y is (N, ...), there is no way to compute the metric without taking the argmax. In this case we should apply argmax unconditionally, IMO.

if y_pred.shape == y.shape, i.e. (N, ...), do not apply argmax.

Yes. In that case, the user can perform the winning-class selection in output_transform, or anywhere else before the metric's update, as sketched below.
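
For illustration, a sketch of such an output_transform with a made-up selection rule (the rule and the 0.05 margin are purely hypothetical, and this assumes the metric accepts index-shaped y_pred as proposed in this issue):

import torch
from ignite.metrics import Accuracy

def pick_class(output):
    # output is assumed to be (logits, y) with logits of shape (N, C)
    logits, y = output
    top2 = logits.topk(2, dim=1)
    # hypothetical rule: prefer the runner-up class when the top-2 scores are nearly tied
    nearly_tied = (top2.values[:, 0] - top2.values[:, 1]) < 0.05
    y_pred = torch.where(nearly_tied, top2.indices[:, 1], top2.indices[:, 0])
    return y_pred, y  # y_pred is now (N, ...), so the metric would not apply argmax

acc = Accuracy(output_transform=pick_class)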

sdesrozis (Contributor) commented

OK, I'll do it soon.

sdesrozis (Contributor) commented

The following code from precision.py (the same applies to recall.py) cannot work when the shape is (N, ...):

elif self._type == "multiclass":
    num_classes = y_pred.size(1)
    if y.max() + 1 > num_classes:
        raise ValueError(
            "y_pred contains less classes than y. Number of predicted classes is {}"
            " and element in y has invalid class = {}.".format(num_classes, y.max().item() + 1)
        )
    y = to_onehot(y.view(-1), num_classes=num_classes)
    indices = torch.argmax(y_pred, dim=1).view(-1)
    y_pred = to_onehot(indices, num_classes=num_classes)
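
The blocker is that num_classes is read from y_pred.size(1), which does not exist when y_pred is (N, ...). One possible workaround, sketched only as a rough idea (the helper name is illustrative, and inferring the class count from a single batch may disagree with the true number of classes):

import torch
from ignite.utils import to_onehot

def onehot_from_indices(y_pred, y):
    # both y_pred and y hold class indices of shape (N, ...): infer the class count from the data
    num_classes = int(max(y.max().item(), y_pred.max().item())) + 1
    y_ohe = to_onehot(y.view(-1), num_classes=num_classes)
    y_pred_ohe = to_onehot(y_pred.view(-1), num_classes=num_classes)
    return y_pred_ohe, y_ohe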

In accuracy.py, the implementation is different (only argmax and an element-wise comparison, no one-hot encoding):

elif self._type == "multiclass":
    indices = torch.argmax(y_pred, dim=1)
    correct = torch.eq(indices, y).view(-1)

sdesrozis (Contributor) commented

Original code

if y.ndimension() + 1 == y_pred.ndimension():
    num_classes = y_pred.shape[1]
    if num_classes == 1:
        update_type = "binary"
        self._check_binary_multilabel_cases((y_pred, y))
    else:
        update_type = "multiclass"
elif y.ndimension() == y_pred.ndimension():
    self._check_binary_multilabel_cases((y_pred, y))
    if self._is_multilabel:
        update_type = "multilabel"
        num_classes = y_pred.shape[1]
    else:
        update_type = "binary"
        num_classes = 1

Modified code

        if y.ndimension() + 1 == y_pred.ndimension():
            # `y` is in the following shape of (batch_size, ...) and
            # `y_pred` is in the following shape of (batch_size, num_categories, ...)
            num_classes = y_pred.shape[1]
            if num_classes == 1:
                update_type = "binary"
                self._check_binary_multilabel_cases((y_pred, y))
            else:
                update_type = "multiclass"
        elif y.ndimension() == y_pred.ndimension():
            if self._is_multilabel:
                # `y` and `y_pred` are in the following shape of (batch_size, num_categories, ...)
                self._check_binary_multilabel_cases((y_pred, y))
                update_type = "multilabel"
                num_classes = y_pred.shape[1]
            else:
                # `y` and `y_pred` are in the following shape of (batch_size, ...)
                # binary type is used because it works in update (no argmax)
                # should we introduce a new type ?
                update_type = "binary"
                num_classes = None

The value of the attribute _num_classes is not really important; it is only used to ensure the consistency of y_pred across update calls, so None should be an acceptable value.
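
A rough sketch of how the consistency check could treat None as "unknown" (illustrative helper, not the actual ignite code):

def check_num_classes_consistency(stored, seen):
    # None means the class count is unknown, so it is compatible with anything
    if stored is None or seen is None:
        return stored if stored is not None else seen
    if stored != seen:
        raise ValueError("num_classes has changed from {} to {}".format(stored, seen))
    return stored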

@vfdev-5 thoughts?

sdesrozis (Contributor) commented

In the tests:

y_pred = torch.rand(10, 4)
y = torch.randint(0, 4, size=(10,)).long()
acc.update((y_pred, y))
np_y_pred = y_pred.numpy().argmax(axis=1).ravel()
np_y = y.numpy().ravel()
assert acc._type == "multiclass"
assert acc._num_classes == 4
assert isinstance(acc.compute(), float)
assert accuracy_score(np_y, np_y_pred) == pytest.approx(acc.compute())

acc.reset()
y_pred_argmax = torch.argmax(y_pred, dim=1)
acc.update((y_pred_argmax, y))
assert acc._type == "binary"
assert acc._num_classes == None
assert isinstance(acc.compute(), float)
assert accuracy_score(np_y, np_y_pred) == pytest.approx(acc.compute())

vfdev-5 (Collaborator, Author) commented Mar 22, 2020

If we make argmax optional, then in the branch elif y.ndimension() == y_pred.ndimension(): we can have the following situations:

  • binary case: y: (N, ...), y_pred: (N, ...), with y ** 2 == y and y_pred ** 2 == y_pred (as previously)
  • multilabel case: y: (N, C, ...), y_pred: (N, C, ...), with y ** 2 == y and y_pred ** 2 == y_pred (as previously)
  • multiclass case (new): y: (N, ...), y_pred: (N, ...) => yes, here we can set num_classes = None (see the sketch after this list)
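
The sketch referenced above, covering the three situations as a standalone helper rather than the actual metric code (illustrative only):

import torch

def infer_type_same_ndim(y_pred, y, is_multilabel):
    # decide the update type when y and y_pred have the same number of dimensions
    if is_multilabel:
        return "multilabel", y_pred.shape[1]                  # (N, C, ...), 0/1 valued
    if torch.equal(y, y ** 2) and torch.equal(y_pred, y_pred ** 2):
        return "binary", 1                                    # (N, ...), 0/1 valued
    return "multiclass", None                                 # (N, ...) class indices, no argmax

Note that the binary/multiclass decision here depends on the values seen in the batch, which is exactly the ambiguity raised two comments below.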

Concerning the tests:

y_pred_argmax = torch.argmax(y_pred, dim=1)
acc.update((y_pred_argmax, y))
assert acc._type == "binary"

the last assert looks strange to me if the data is originally multiclass.

sdesrozis (Contributor) commented

OK, so I use the predicate y ** 2 == y and y_pred ** 2 == y_pred to discriminate. Got it.

sdesrozis (Contributor) commented

Consider y: (N, ...) and y_pred: (N, ...). Across update calls, the type could switch from multiclass to binary and vice versa, since it depends on the values in the batch and on the check used to decide whether y and y_pred correspond to the binary type.

That means we can't reliably infer the type of the metric inside update. Right? Should it be information given by the user instead (we already have a multilabel flag, why not binary)?
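
A concrete instance of that ambiguity (illustrative): the same pair of tensors is both a valid binary input and a valid batch of 2-class indices, so the values alone cannot decide the type:

import torch

y = torch.tensor([0, 1, 1, 0])
y_pred = torch.tensor([0, 1, 0, 1])

# Both tensors pass the 0/1 check, so this batch would be typed "binary",
# even if the task is actually multiclass and the other classes simply do
# not appear in this particular batch.
print(torch.equal(y, y ** 2), torch.equal(y_pred, y_pred ** 2))  # True True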

@vfdev-5 thoughts?
