Make optional argmax for y_pred in Confusion Matrix, Precision, Recall, Accuracy #822
Comments
So … Is it ok?

If yes. In that case, the user can perform winning class selection in …

Ok, I will do it soon.
The following code is from `ignite/ignite/metrics/precision.py`, lines 144 to 153 at cfb1b3a.

From `ignite/ignite/metrics/accuracy.py`, lines 154 to 156 at cfb1b3a.

Original code: `ignite/ignite/metrics/accuracy.py`, lines 62 to 77 at cfb1b3a.
Modified code:

```python
if y.ndimension() + 1 == y_pred.ndimension():
    # `y` is in the following shape of (batch_size, ...) and
    # `y_pred` is in the following shape of (batch_size, num_categories, ...)
    num_classes = y_pred.shape[1]
    if num_classes == 1:
        update_type = "binary"
        self._check_binary_multilabel_cases((y_pred, y))
    else:
        update_type = "multiclass"
elif y.ndimension() == y_pred.ndimension():
    if self._is_multilabel:
        # `y` and `y_pred` are in the following shape of (batch_size, num_categories, ...)
        self._check_binary_multilabel_cases((y_pred, y))
        update_type = "multilabel"
        num_classes = y_pred.shape[1]
    else:
        # `y` and `y_pred` are in the following shape of (batch_size, ...)
        # binary type is used because it works in update (no argmax)
        # should we introduce a new type?
        update_type = "binary"
        num_classes = None
```

The value of the parameter … @vfdev-5 thoughts?
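As a standalone illustration of the shape-based dispatch above, here is a minimal sketch in plain NumPy; `detect_type` is a hypothetical helper mirroring the checks in the snippet, not part of ignite's API:

```python
import numpy as np

def detect_type(y_pred, y, is_multilabel=False):
    # Hypothetical helper mirroring the shape checks above (not ignite's API).
    if y.ndim + 1 == y_pred.ndim:
        # y: (batch_size, ...), y_pred: (batch_size, num_categories, ...)
        num_classes = y_pred.shape[1]
        return ("binary" if num_classes == 1 else "multiclass"), num_classes
    if y.ndim == y_pred.ndim:
        if is_multilabel:
            return "multilabel", y_pred.shape[1]
        # Same shape, no class axis: falls into the "binary" branch (no argmax).
        return "binary", None
    raise ValueError("incompatible shapes")

y_pred = np.random.rand(10, 4)             # scores, shape (10, 4)
y = np.random.randint(0, 4, size=(10,))    # labels, shape (10,)
print(detect_type(y_pred, y))              # ('multiclass', 4)
print(detect_type(y_pred.argmax(axis=1), y))  # ('binary', None)
```

Note how the already-argmax'ed `y_pred` lands in the `"binary"` branch with `num_classes = None`, which is exactly the case discussed below.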
In test:

```python
y_pred = torch.rand(10, 4)
y = torch.randint(0, 4, size=(10,)).long()
acc.update((y_pred, y))
np_y_pred = y_pred.numpy().argmax(axis=1).ravel()
np_y = y.numpy().ravel()
assert acc._type == "multiclass"
assert acc._num_classes == 4
assert isinstance(acc.compute(), float)
assert accuracy_score(np_y, np_y_pred) == pytest.approx(acc.compute())

acc.reset()
y_pred_argmax = torch.argmax(y_pred, dim=1)
acc.update((y_pred_argmax, y))
assert acc._type == "binary"
assert acc._num_classes is None
assert isinstance(acc.compute(), float)
assert accuracy_score(np_y, np_y_pred) == pytest.approx(acc.compute())
```
If we make argmax optional, then in the branch …

Concerning the tests:

```python
y_pred_argmax = torch.argmax(y_pred, dim=1)
acc.update((y_pred_argmax, y))
assert acc._type == "binary"
```

the last assert looks strange to me if the data is originally multi-class.
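To make the concern concrete: once argmax is optional, a 1-D `y_pred` is ambiguous on shape alone. A minimal sketch in plain NumPy (not ignite code):

```python
import numpy as np

# Two inputs with identical shape but different semantics:
binary_scores = np.array([0.1, 0.9, 0.4])  # binary probabilities, shape (3,)
class_indices = np.array([2, 0, 3])        # argmax'ed multiclass predictions, shape (3,)

# Shape-based type inference cannot tell them apart:
assert binary_scores.shape == class_indices.shape
# So an argmax'ed multiclass y_pred lands in the "binary" branch,
# which is why `acc._type == "binary"` reads strangely for multi-class data.
```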
Ok, so I use the predicate …
Consider … It means we can't infer the type of the metric using … @vfdev-5 thoughts?
🚀 Feature
Today, the conditions on the input of the Confusion Matrix (and Precision, Recall, Accuracy in the multiclass case) are the following:
Taking argmax on `y_pred` can be an option if we would like to determine the winning class by some other rule. Let's keep argmax as the default behaviour if `y_pred` is `(N, C, ...)`, and not apply it if `y_pred.shape == y.shape` and is `(N, ...)`.
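The proposed default can be sketched as follows; this is a minimal NumPy sketch, and `select_winning_class` is a hypothetical name, not a function in ignite:

```python
import numpy as np

def select_winning_class(y_pred, y):
    # Hypothetical helper sketching the proposed default behaviour.
    if y_pred.ndim == y.ndim + 1:
        # y_pred is (N, C, ...): keep argmax as the default rule.
        return y_pred.argmax(axis=1)
    if y_pred.shape == y.shape:
        # y_pred is already (N, ...): the user determined the winning
        # class by some other rule, so do not apply argmax.
        return y_pred
    raise ValueError("y_pred must be (N, C, ...) or match y's shape (N, ...)")

y = np.zeros(4, dtype=np.int64)
scores = np.array([[0.2, 0.8], [0.9, 0.1], [0.4, 0.6], [0.7, 0.3]])
print(select_winning_class(scores, y))                  # argmax applied -> [1 0 1 0]
print(select_winning_class(np.array([1, 0, 1, 0]), y))  # passed through unchanged
```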