Avoid nan loss when there are labels with no samples in the training data. #12

Open · wants to merge 2 commits into main
Conversation


@chbeltz chbeltz commented Nov 22, 2024

Hello there.

I ran into a problem today when doing a test run with training data that lacked samples for one of the labels: this causes the class-balanced focal loss to come out as nan.

```python
import torch
from balanced_loss import Loss

samples_per_class = list(torch.tensor([ 310., 2489.,  114.,   17.,    0.,  725.]))
pred = torch.tensor([[4.1951e-04, 1.6066e-02, 3.2661e-03, 5.0763e-01, 1.0739e-03, 4.7154e-01],
        [7.6719e-03, 1.1280e-01, 5.8755e-02, 5.5621e-02, 6.6679e-01, 9.8361e-02],
        [3.0145e-03, 9.3653e-01, 1.7860e-02, 2.4776e-03, 3.6712e-03, 3.6448e-02],
        [1.0764e-03, 3.8136e-03, 4.5988e-03, 8.3224e-04, 9.8502e-01, 4.6638e-03],
        [9.5827e-03, 2.3838e-02, 5.1518e-02, 1.0943e-02, 2.9569e-02, 8.7455e-01]])
yb = torch.tensor([[0., 0., 0., 0., 0., 1.],
        [1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0.],
        [1., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0.]])
focal_loss = Loss(
    loss_type="focal_loss",
    samples_per_class=samples_per_class,
    beta=0.999, # class-balanced loss beta
    fl_gamma=2, # focal loss gamma
    class_balanced=True
)
print(focal_loss(pred, torch.argmax(yb, dim=-1).to(torch.int64)))
```

currently yields

```
tensor(nan)

/home/user/florist-environment/lib/python3.10/site-packages/balanced_loss/losses.py:111: RuntimeWarning: divide by zero encountered in divide
  weights = (1.0 - self.beta) / np.array(effective_num)
/home/user/florist-environment/lib/python3.10/site-packages/balanced_loss/losses.py:112: RuntimeWarning: invalid value encountered in divide
  weights = weights / np.sum(weights) * effective_num_classes
```
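For context, the nan traces back to the effective-number formula used for class-balanced weighting: with beta^0 = 1, a class with zero samples has an effective number of 0, so its raw weight is infinite, and the subsequent normalization produces nan. A minimal standalone reproduction of the weight computation (variable names assumed, mirroring the warnings above):

```python
import numpy as np

beta = 0.999
samples_per_class = np.array([310., 2489., 114., 17., 0., 725.])

# Effective number of samples: 1 - beta^n. For a class with n = 0
# samples this is exactly 0, so dividing by it yields inf.
effective_num = 1.0 - np.power(beta, samples_per_class)

with np.errstate(divide="ignore", invalid="ignore"):
    weights = (1.0 - beta) / effective_num  # inf for the empty class
    # Normalizing divides by an infinite sum: inf / inf -> nan.
    weights = weights / np.sum(weights) * len(samples_per_class)

print(weights)  # the empty class's weight is nan, which poisons the loss
```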

Adding a safe switch to the Loss class fixes this issue without changing the weights of the non-zero-sample labels relative to leaving out the zero-sample labels. The loss, however, will come out larger than it would with the alternative solution of removing the offending label.

I can see that this is an edge case, but it will be helpful for me, and I imagine it might be for others as well. One could also consider raising a ValueError when labels with no samples are supplied, hinting at the safe switch.
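One possible shape for such a safe switch, sketched as a standalone helper rather than the PR's actual diff (the function name, the `safe` flag, and the normalization choice are all assumptions):

```python
import numpy as np

def class_balanced_weights(samples_per_class, beta, safe=True):
    """Class-balanced weights with an optional guard for zero-sample classes."""
    samples = np.asarray(samples_per_class, dtype=float)
    effective_num = 1.0 - np.power(beta, samples)  # 0.0 where samples == 0
    nonzero = effective_num > 0
    if not safe and not nonzero.all():
        raise ValueError(
            "samples_per_class contains classes with no samples; "
            "pass safe=True to assign them zero weight"
        )
    weights = np.zeros_like(samples)
    weights[nonzero] = (1.0 - beta) / effective_num[nonzero]
    # Normalize over the populated classes only, so the non-zero-sample
    # labels keep the same weights they would get if the empty classes
    # were dropped from samples_per_class entirely.
    weights = weights / weights.sum() * nonzero.sum()
    return weights
```

With the data from the repro above, the empty class simply receives weight 0 and the remaining weights are finite, matching what one would get by removing that label from `samples_per_class`.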
