
SemanticSegmentationTask: fix ignore_index weighting #1331

Merged 2 commits into main from fixes/ignore_index on May 14, 2023

Conversation

adamjstewart (Collaborator)

Fixes the second half of #1245. The following script can be used to confirm the bug and the fix:

import torch
from torchmetrics.classification import MulticlassAccuracy

metric = MulticlassAccuracy(num_classes=3, ignore_index=0, average="weighted")
prediction = torch.ones(20, 10)  # predict class 1 for every pixel

# Scenario 1: the first batch (rows 0-9) mixes ignored pixels (class 0)
# with one row of class 1; the second batch (rows 10-19) is all class 2.
truth = torch.zeros_like(prediction)
truth[9] = 1
truth[10:] = 2

metric(prediction[:10], truth[:10])
metric(prediction[10:], truth[10:])
print(metric.compute())
metric.reset()

# Scenario 2: shift the class boundary by one row so that the first batch
# contains *only* ignored pixels, exercising the edge case where a batch
# contributes nothing to the metric.
truth = torch.zeros_like(prediction)
truth[10] = 1
truth[11:] = 2

metric(prediction[:10], truth[:10])
metric(prediction[10:], truth[10:])
print(metric.compute())

The default averaging technique, "macro", produces incorrect values. Both "micro" and "weighted" compute the correct values. I believe "micro" accumulates statistics for the entire epoch in memory and computes the metric once at the end, while "weighted" computes the metric and the fraction of non-ignored pixels per batch and only combines those at the end of the epoch, so "weighted" should have lower memory requirements.
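One way to picture the memory argument above — a pure-Python sketch, not TorchMetrics' actual internals, with a hypothetical `StreamingWeightedAccuracy` accumulator — is that per-batch weighting only needs O(1) running state (a weighted sum and a pixel count), rather than the whole epoch's predictions:

```python
class StreamingWeightedAccuracy:
    """Hypothetical accumulator illustrating batch-wise weighting.

    Keeps only two scalars of state, no matter how many batches are seen.
    """

    def __init__(self, ignore_index=0):
        self.ignore_index = ignore_index
        self.weighted_sum = 0.0  # sum of batch_accuracy * n_valid_pixels
        self.total_valid = 0     # total non-ignored pixels seen so far

    def update(self, preds, target):
        # Keep only the pixels whose label is not ignored.
        pairs = [(p, t) for p, t in zip(preds, target)
                 if t != self.ignore_index]
        if not pairs:
            return  # batch contained only ignored pixels; contributes nothing
        correct = sum(1 for p, t in pairs if p == t)
        n = len(pairs)
        # Weight the batch metric by the number of valid pixels.
        self.weighted_sum += (correct / n) * n
        self.total_valid += n

    def compute(self):
        return self.weighted_sum / self.total_valid


acc = StreamingWeightedAccuracy(ignore_index=0)
acc.update([1, 1, 1, 1], [0, 0, 1, 2])  # 2 valid pixels, 1 correct
acc.update([1, 1], [2, 2])              # 2 valid pixels, 0 correct
print(acc.compute())                    # 0.25 = 1 correct / 4 valid
```

Because a batch of only ignored pixels simply skips the update, this scheme gives the same answer regardless of where the ignored pixels fall across batch boundaries.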

@adamjstewart adamjstewart added this to the 0.4.2 milestone May 12, 2023
@github-actions github-actions bot added the trainers PyTorch Lightning trainers label May 12, 2023
isaaccorley (Collaborator) commented May 12, 2023

We should still use micro instead since this represents overall mIoU. Most papers do not show weighted mIoU. Weighted means it's a weighted average across the individual classes, not across batches.
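The distinction between the three modes can be sketched in pure Python (a hypothetical `multiclass_accuracy` helper, not torchmetrics code): "weighted" averages per-class recalls weighted by class support, which for plain accuracy agrees with "micro" (the overall fraction correct), while "macro" weights every class equally regardless of how rare it is:

```python
from collections import Counter


def multiclass_accuracy(preds, target, average):
    """Sketch of the averaging modes; assumes 1-D integer label lists."""
    support = Counter(target)  # samples per ground-truth class
    correct = Counter(t for p, t in zip(preds, target) if p == t)
    if average == "micro":
        # Overall fraction of correctly classified samples.
        return sum(correct.values()) / len(target)
    # Per-class recall for every class present in the target.
    recalls = {c: correct[c] / support[c] for c in support}
    if average == "macro":
        # Unweighted mean over classes: rare classes count as much as common ones.
        return sum(recalls.values()) / len(recalls)
    if average == "weighted":
        # Mean over classes weighted by class support.
        return sum(recalls[c] * support[c] / len(target) for c in support)
    raise ValueError(average)


preds = [0, 0, 1, 1, 1, 1]
target = [0, 1, 1, 1, 1, 1]
print(multiclass_accuracy(preds, target, "micro"))     # 5/6
print(multiclass_accuracy(preds, target, "macro"))     # (1.0 + 0.8) / 2 = 0.9
print(multiclass_accuracy(preds, target, "weighted"))  # 5/6, same as micro
```

This is why the PR's script prints the correct value under both "micro" and "weighted": the weighting happens across classes (and, in streaming form, across batches by valid-pixel count), not as an unweighted average of per-batch metrics.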

isaaccorley (Collaborator)

Great catch on this, actually. The majority of other torchmetrics default to average="micro", but MulticlassJaccardIndex does not.

adamjstewart (Collaborator, Author)

Would be good if someone could test this and make sure memory usage doesn't explode.

@adamjstewart adamjstewart merged commit e262134 into main May 14, 2023
@adamjstewart adamjstewart deleted the fixes/ignore_index branch May 14, 2023 16:02
@adamjstewart adamjstewart modified the milestones: 0.4.2, 0.5.0 Sep 28, 2023