
Modify Jaccard, Dice and Tversky losses #927

Merged: 2 commits into qubvel-org:main on Oct 2, 2024

Conversation

@zifuwanggg (Contributor) commented on Sep 16, 2024

The Jaccard, Dice and Tversky losses in losses._functional are modified based on JDTLoss.

  • Since Jaccard and Dice losses are special cases of the Tversky loss [1], the implementation is simplified by calling soft_tversky_score when calculating both jaccard_score and dice_score.

  • The original loss functions are incompatible with soft labels. For example, with a ground truth value of 0.5 for a single pixel, the Dice loss is minimized when the predicted value is 1, which is clearly erroneous. To address this, the intersection term is rewritten as $\frac{|x|_1 + |y|_1 - |x-y|_1}{2}$. This reformulation has been proven to retain equivalence with the original versions when the ground truth is binary (i.e. one-hot hard labels), while resolving the issue with soft labels [1, 2].
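Since Dice and Jaccard correspond to the α = β = 0.5 and α = β = 1 cases of Tversky, the simplification can be sketched as below. This is a minimal illustration with a hypothetical helper name, not the library's actual soft_tversky_score signature:

```python
import torch

def soft_tversky(x, y, alpha, beta, dims):
    # TP via the rewritten intersection from [1]: (|x|_1 + |y|_1 - |x - y|_1) / 2
    tp = (x.sum(dims) + y.sum(dims) - (x - y).abs().sum(dims)) / 2
    fp = x.sum(dims) - tp
    fn = y.sum(dims) - tp
    return tp / (tp + alpha * fp + beta * fn)

torch.manual_seed(0)
dims = (0, 2, 3)
x = torch.rand(2, 3, 8, 8).softmax(dim=1)
y = torch.rand(2, 3, 8, 8).softmax(dim=1)

dice = soft_tversky(x, y, alpha=0.5, beta=0.5, dims=dims)    # Dice: alpha = beta = 0.5
jaccard = soft_tversky(x, y, alpha=1.0, beta=1.0, dims=dims)  # Jaccard: alpha = beta = 1
assert torch.allclose(jaccard, dice / (2 - dice))  # the usual Dice/Jaccard identity
```

The assert at the end checks the standard identity J = D / (2 - D), which is why a single Tversky implementation can serve both scores.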

Example

import torch
import torch.linalg as LA
import torch.nn.functional as F

torch.manual_seed(0)

b, c, h, w = 4, 3, 32, 32
dims = (0, 2, 3)

pred = torch.rand(b, c, h, w).softmax(dim=1)
soft_label = torch.rand(b, c, h, w).softmax(dim=1)
hard_label = torch.randint(low=0, high=c, size=(b, h, w))
one_hot_label = F.one_hot(hard_label, c).permute(0, 3, 1, 2)

def dice_old(x, y, dims):
    # original formulation: intersection as the inner product <x, y>
    cardinality = torch.sum(x, dim=dims) + torch.sum(y, dim=dims)
    intersection = torch.sum(x * y, dim=dims)
    return 2 * intersection / cardinality

def dice_new(x, y, dims):
    # rewritten formulation: intersection recovered from the L1 distance,
    # (|x|_1 + |y|_1 - |x - y|_1) / 2, which also handles soft labels
    cardinality = torch.sum(x, dim=dims) + torch.sum(y, dim=dims)
    difference = LA.vector_norm(x - y, ord=1, dim=dims)
    intersection = (cardinality - difference) / 2
    return 2 * intersection / cardinality

print(dice_old(pred, one_hot_label, dims), dice_new(pred, one_hot_label, dims))
print(dice_old(pred, soft_label, dims), dice_new(pred, soft_label, dims))
print(dice_old(pred, pred, dims), dice_new(pred, pred, dims))

# tensor([0.3345, 0.3310, 0.3317]) tensor([0.3345, 0.3310, 0.3317])
# tensor([0.3321, 0.3333, 0.3350]) tensor([0.8680, 0.8690, 0.8700])
# tensor([0.3487, 0.3502, 0.3544]) tensor([1., 1., 1.])
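The equivalence of the two formulations for binary labels follows from the identity |x - y|_1 = |x|_1 + |y|_1 - 2⟨x, y⟩ when y ∈ {0, 1} and x ∈ [0, 1]. A quick numerical check (illustrative snippet, not part of the PR):

```python
import torch

torch.manual_seed(0)
x = torch.rand(1000)                       # soft predictions in [0, 1)
y = torch.randint(0, 2, (1000,)).float()   # binary (hard) labels

inter_old = (x * y).sum()                                  # original inner-product form
inter_new = (x.sum() + y.sum() - (x - y).abs().sum()) / 2  # rewritten L1 form
assert torch.allclose(inter_old, inter_new)  # identical whenever y is binary
```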

References

[1] Dice Semimetric Losses: Optimizing the Dice Score with Soft Labels. Zifu Wang, Teodora Popordanoska, Jeroen Bertels, Robin Lemmens, Matthew B. Blaschko. MICCAI 2023.

[2] Jaccard Metric Losses: Optimizing the Jaccard Index with Soft Labels. Zifu Wang, Xuefei Ning, Matthew B. Blaschko. NeurIPS 2023.

@qubvel (Collaborator) commented:
Hi @zifuwanggg, thanks for updating it! Looks good to me!
Can you please add a test case for the updated losses? Similar to your snippets, you can take the expected results from the current Dice implementation and test them against the updated function.

P.S. just figured out there are tests for these functions 😄 let's see if they pass

Comment on lines 197 to 199
intersection = (output_sum + target_sum - difference) / 2 # TP
fp = output_sum - intersection
fn = target_sum - intersection
@qubvel (Collaborator):

Let's have this moved out of the if/else

@zifuwanggg (Author):

Hi @qubvel, thanks for reviewing!

Since dim=None is the default in torch.norm and torch.sum, should we just remove the if/else?
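For reference, the dim=None default behavior can be verified with a short snippet (illustrative, not part of the PR):

```python
import torch
import torch.linalg as LA

torch.manual_seed(0)
x = torch.rand(3, 4)

# with dim=None (the default) both reductions run over all elements,
# so separate with-dims / without-dims branches collapse into one call
assert torch.allclose(LA.vector_norm(x, ord=1), x.abs().sum())
assert torch.allclose(torch.sum(x), x.flatten().sum())
```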

@qubvel (Collaborator):

Sounds good 👍

@zifuwanggg (Author):

I've removed the if/else, and changed the torch.norm call from my previous commit to torch.linalg.vector_norm, since torch.norm is deprecated and may be removed in a future PyTorch release.
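For the ord=1 case used here, the migration is a drop-in replacement (illustrative check, not part of the PR):

```python
import torch

torch.manual_seed(0)
x = torch.rand(2, 3, 4)

# torch.norm(p=1) is deprecated; torch.linalg.vector_norm(ord=1)
# computes the same L1 norm over the given dimensions
old = torch.norm(x, p=1, dim=(0, 2))
new = torch.linalg.vector_norm(x, ord=1, dim=(0, 2))
assert torch.allclose(old, new)
```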

@qubvel (Collaborator) commented:

Thanks for working on this! I think we are good to go!

@qubvel qubvel merged commit d989faa into qubvel-org:main Oct 2, 2024
2 checks passed