Modify Jaccard, Dice and Tversky losses #927
Conversation
Hi @zifuwanggg, thanks for updating it! Looks good to me!
Could you please design a test case for the updated losses? Similar to your snippets, you can take the expected results from the current Dice implementation and test them against the updated function.
P.S. Just figured out there are already tests for these functions 😄 let's see if they pass.
intersection = (output_sum + target_sum - difference) / 2  # TP
fp = output_sum - intersection
fn = target_sum - intersection
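The reviewed lines above can be sketched as a small self-contained function. This is an illustrative reconstruction, not the library's actual implementation: the function name `soft_tversky_score_sketch` and the `smooth`/`eps` parameters are assumptions, and plain Python lists stand in for tensors. It shows how TP, FP and FN are all derived from L1 norms via the reformulated intersection.

```python
def soft_tversky_score_sketch(output, target, alpha, beta, smooth=0.0, eps=1e-7):
    """Illustrative soft Tversky score using the reformulated intersection.

    intersection = (|x|_1 + |y|_1 - |x - y|_1) / 2, which equals the usual
    TP count for binary labels but also behaves sensibly for soft labels.
    """
    output_sum = sum(abs(o) for o in output)          # |x|_1
    target_sum = sum(abs(t) for t in target)          # |y|_1
    difference = sum(abs(o - t) for o, t in zip(output, target))  # |x - y|_1

    intersection = (output_sum + target_sum - difference) / 2  # TP
    fp = output_sum - intersection
    fn = target_sum - intersection

    # Tversky index: alpha = beta = 0.5 recovers Dice, alpha = beta = 1 recovers Jaccard.
    return (intersection + smooth) / (intersection + alpha * fp + beta * fn + smooth + eps)
```

With binary inputs this matches the familiar set-based scores, e.g. `soft_tversky_score_sketch([1.0, 1.0], [1.0, 0.0], 1.0, 1.0)` gives the Jaccard index 0.5 of the sets {1, 2} and {1}.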
Let's have this moved out of the if/else
Hi @qubvel, thanks for reviewing!
Since dim=None is the default in torch.norm and torch.sum, should we just remove the if/else?
Sounds good 👍
I've removed the if/else and, in my previous commit, changed torch.norm to torch.linalg.vector_norm, since torch.norm is deprecated and may be removed in a future PyTorch release.
Thanks for working on this! I think we are good to go!
The Jaccard, Dice and Tversky losses in `losses._functional` are modified based on JDTLoss. Since the Jaccard and Dice losses are special cases of the Tversky loss [1], the implementation is simplified by calling `soft_tversky_score` when calculating both `jaccard_score` and `dice_score`.

The original loss functions are incompatible with soft labels. For example, with a ground truth value of 0.5 for a single pixel, the Dice loss is minimized when the predicted value is 1, which is clearly erroneous. To address this, the intersection term is rewritten as $\frac{\|x\|_1 + \|y\|_1 - \|x-y\|_1}{2}$. This reformulation has been proven to be equivalent to the original versions when the ground truth is binary (i.e. one-hot hard labels), while resolving the issue with soft labels [1, 2].
Example
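A minimal numeric sketch of the soft-label problem described above (not the library code; the helper names `dice_product` and `dice_reformulated` are hypothetical). For a single pixel with ground truth y = 0.5, the product-based Dice score keeps growing as the prediction moves toward 1, while the reformulated intersection peaks at the correct prediction x = 0.5.

```python
def dice_product(x, y):
    # Original-style soft Dice for one pixel: intersection as the product x*y.
    return 2 * (x * y) / (x + y)

def dice_reformulated(x, y):
    # Reformulated intersection: (|x|_1 + |y|_1 - |x - y|_1) / 2.
    inter = (abs(x) + abs(y) - abs(x - y)) / 2
    return 2 * inter / (x + y)

y = 0.5
candidates = [i / 10 for i in range(1, 11)]  # predictions 0.1 .. 1.0

# Which prediction maximizes each score (i.e. minimizes the loss 1 - score)?
best_product = max(candidates, key=lambda x: dice_product(x, y))
best_reform = max(candidates, key=lambda x: dice_reformulated(x, y))

print(best_product)  # 1.0 -- pushed to a hard prediction despite y = 0.5
print(best_reform)   # 0.5 -- agrees with the soft ground truth
```

For binary ground truth (y equal to 0 or 1), the two formulations coincide, which is why the change is backward compatible with one-hot hard labels.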
References
[1] Dice Semimetric Losses: Optimizing the Dice Score with Soft Labels. Zifu Wang, Teodora Popordanoska, Jeroen Bertels, Robin Lemmens, Matthew B. Blaschko. MICCAI 2023.
[2] Jaccard Metric Losses: Optimizing the Jaccard Index with Soft Labels. Zifu Wang, Xuefei Ning, Matthew B. Blaschko. NeurIPS 2023.