Multi-class focal loss #3250
Hi, the discussion for adding loss functions to torchvision is here: #2980.
@oke-aditya Thank you for adding my request to your task list. Is there an estimated timeline for completing #2980? Also, in the meantime, do you have any feedback on the proposed code snippet here - is the implementation correct?
I don't have any estimated timeline for loss functions, since they are not currently on the torchvision roadmap (#3221).
Okay, understood. Excellent, I look forward to your feedback on the implementation, and I'm sure others looking for a reference will find it helpful as well!
I have been using pretty much the same approach for computing a multi-class focal loss, but I was encountering some numerical instabilities.
I think the use of softmax differs from the official implementation. In addition, I had contact with one of the authors at one point (Tsung-Yi Lin), who confirmed they are not using softmax, but that it was more a matter of preference. Could you explain the use of the term "multi-class" in the name for me? The focal loss function in torchvision is used on datasets like COCO, which contain multiple classes. Why is that not a multi-class focal loss?
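For reference, here is a minimal sketch of the softmax-based multi-class focal loss being discussed, using `log_softmax` to avoid the kind of numerical instabilities mentioned above. The function name, signature, and optional per-class `alpha` weighting are illustrative assumptions, not an official API:

```python
import torch
import torch.nn.functional as F

def multiclass_focal_loss(logits, targets, gamma=2.0, alpha=None):
    # logits: (N, C) raw scores; targets: (N,) integer class indices
    log_p = F.log_softmax(logits, dim=-1)  # stable alternative to log(softmax(x))
    log_pt = log_p.gather(1, targets.unsqueeze(1)).squeeze(1)  # log-prob of true class
    pt = log_pt.exp()
    loss = -((1 - pt) ** gamma) * log_pt  # modulating factor * cross entropy
    if alpha is not None:  # optional per-class weights, shape (C,)
        loss = alpha[targets] * loss
    return loss.mean()

logits = torch.randn(4, 3)
targets = torch.tensor([0, 2, 1, 1])
loss = multiclass_focal_loss(logits, targets)
```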
@hgaiser I think softmax (provided by the proposed implementation) handles mutually exclusive classes, whereas the existing torchvision loss is inherently binary. A dataset like COCO would be my intended use case, but I don't see how the existing torchvision implementation can handle this. The docstring for `sigmoid_focal_loss` says the targets must store binary classification labels (0 for the negative class, 1 for the positive class) for each element. If I try to predict multiple class labels (below) with a detection model, the existing binary loss function cannot be used: there is no way to encode integer class indices like 0, 1, 2 as a single binary label.
I think the focal loss in torchvision is a binary focal loss, and it is used to train multi-class models as follows. Here is the code that does it. Here are implementations of multi-class focal loss, with and without logits, as well as binary focal loss, with and without logits. What I understand is that the current focal loss works for the multi-class case by the above logic; this is multi-class object detection, no doubt. What a truly multi-class focal loss would do is directly calculate the loss value from targets 1 - N against the ground-truth class; we would not need to binarize them. But will this method work for detection models? I mean, is there any proof of concept for how a multi-class focal loss is used in detection tasks?
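To illustrate the binarization logic described above (a rough sketch, not the actual torchvision/RetinaNet code), each foreground class is treated in turn as the positive class for the binary loss. Summing these per-class terms is equivalent to applying the binary loss to a one-hot target matrix:

```python
import torch
from torchvision.ops import sigmoid_focal_loss

num_classes = 3
logits = torch.randn(8, num_classes)           # per-class scores from the model
targets = torch.randint(0, num_classes, (8,))  # integer ground-truth labels

total_loss = 0.0
for c in range(num_classes):
    binary_targets = (targets == c).float()    # 1 for class c, 0 otherwise
    total_loss = total_loss + sigmoid_focal_loss(
        logits[:, c], binary_targets, reduction='mean'
    )
```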
@oke-aditya Interesting, perhaps my definition of multi-class focal loss is off then, since it seems like we're talking about the same goal (training an object detector with multiple foreground classes). In that case, looping through each foreground class seems rather inefficient if we could just compute them all at once with a single softmax + cross entropy as Thomas and @Nuno-Mota suggest. The RetinaNet loop you linked also requires converting the multi-class softmax predicted by the model into a binary representation. With my proposed code, we could use the model prediction directly without having to translate it into another format. Do those advantages make it clear why I'd like a native multi-class loss function? If my proposed code is invalid, perhaps the existing loss function could handle the foreground class looping internally. Then we could pass model class predictions directly, and there would be less room for user error.
Ah okay, I think I see some misunderstandings. I'll try to explain some things; hopefully that clears things up a bit. By the way, I am the author of the RetinaNet PR for torchvision, and I have largely implemented keras-retinanet as well.
This is a long post, so I will keep it at this. I hope this helped clear things up a bit 👍. Let me know if you have any questions.
@hgaiser Thank you for taking the time to clarify in a longer explanation - that helps a lot! Previously, I thought softmax + a dedicated background class was the accepted approach in object detection, so I didn't realize there was an option to use sigmoid + one-hot encodings. Would this be an appropriate use of the existing loss function according to your instructions?

```python
import random

import torch
from torchvision.ops import sigmoid_focal_loss

batch_size = 8
num_classes = 3

# Randomly assign the GT class for each sample in the batch
gt = torch.zeros(batch_size, num_classes)
for b in range(batch_size):
    class_idx = random.randint(0, num_classes - 1)
    gt[b, class_idx] = 1

# Get fake predictions from our "model" and compare
logits = torch.randn(batch_size, num_classes)
loss = sigmoid_focal_loss(logits, gt)
```
@addisonklinke yes, I would think that is correct.
@addisonklinke just curious to see if you had tested this?
Unfortunately I haven't yet. This project got pushed to the back burner, but I am hoping to get back to it in the next couple of months.
I think that for a multi-class classification task, this implementation looks like the correct one.
@hgaiser Thanks a lot for this explanation; that really clarifies everything. I think one additional thing to be aware of when sigmoids are used in the loss is that this should be reflected during inference. It would not be necessary if we just predict the class with the max logit, but it would be needed in situations where we want to exploit the value of the probabilities, e.g. abstaining from a decision below a probability threshold. In a nutshell: if model calibration is important, we must reuse the sigmoid to infer probabilities if sigmoids were used in the focal loss.
Hmm, I'm not sure I'm following. If you use a sigmoid, you can still interpret the classification as a confidence probability (as its values are between 0 and 1), just not w.r.t. the other classes.
Oh yes, of course. I just meant that whatever is used in the loss calculation must also be used to get the probabilities.
Ah, like that. Yeah, then I absolutely agree :).
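To make that concrete, a small hypothetical inference snippet: since the loss was computed on sigmoid outputs, calibrated per-class probabilities at test time should also come from a sigmoid, which matters as soon as you threshold on them:

```python
import torch

logits = torch.tensor([[2.0, -1.0, 0.5]])  # raw scores for one sample, 3 classes

# Match training: per-class probabilities come from the sigmoid, not a softmax
probs = torch.sigmoid(logits)

# argmax is unaffected by the choice of squashing function...
pred_class = probs.argmax(dim=-1)
# ...but thresholding is not, e.g. abstaining on classes below 0.5
confident = probs > 0.5
```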
Thank you for the explanation @hgaiser. I am using a UNet for multi-class segmentation, and I was wondering if I need to input the raw logits or the softmax probabilities to the loss function. Here is my current training code:

```python
import time

import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm

# UNet, Action, prepare_batch, dice_loss, ce_loss, device, NUM_CLASSES, and
# CHANNELS_DIMENSION are defined elsewhere in my script


def get_model_and_optimizer(device):
    model = UNet(
        in_channels=1,
        out_classes=NUM_CLASSES,
        dimensions=3,
        num_encoding_blocks=3,
        out_channels_first_layer=8,
        normalization='batch',
        upsampling_type='linear',
        padding=True,
        activation='PReLU',
    ).to(device)
    optimizer = torch.optim.AdamW(model.parameters())
    return model, optimizer


def run_epoch(epoch_idx, action, loader, model, optimizer):
    is_training = action == Action.TRAIN
    epoch_losses = []
    times = []
    model.train(is_training)
    for batch_idx, batch in enumerate(tqdm(loader)):
        inputs, targets = prepare_batch(batch, device)
        optimizer.zero_grad()
        with torch.set_grad_enabled(is_training):
            logits = model(inputs)
            probabilities = F.softmax(logits, dim=CHANNELS_DIMENSION)
            batch_losses = dice_loss(probabilities, targets)
            # batch_losses = ce_loss(probabilities, targets)
            batch_loss = batch_losses.mean()
            if is_training:
                batch_loss.backward()
                optimizer.step()
        times.append(time.time())
        epoch_losses.append(batch_loss.item())
    epoch_losses = np.array(epoch_losses)
    print(f"{action.value} mean loss: {epoch_losses.mean():0.3f}")
    return times, epoch_losses
```
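If torchvision's loss were swapped in here, one thing worth noting: `sigmoid_focal_loss` expects raw logits (it applies the sigmoid internally), so the explicit softmax would be dropped. A rough sketch of how the inner loop might change, assuming a hypothetical `targets_onehot` tensor prepared with the same layout as the logits:

```python
from torchvision.ops import sigmoid_focal_loss

# Replacing the softmax + dice_loss lines inside the training loop:
logits = model(inputs)  # raw scores, e.g. shape (N, C, D, H, W)
# targets_onehot is assumed to be a float tensor with the same shape as logits
batch_loss = sigmoid_focal_loss(logits, targets_onehot, reduction='mean')
```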
In image segmentation tasks, focal loss is useful when trying to classify an image pixel as one of N classes. Unfortunately, `sigmoid_focal_loss()` isn't useful in such cases. I found that others have been asking for this as well here pytorch#3250, so I decided to submit a PR for the same.
🚀 Feature
Define an official multi-class focal loss function
Motivation
Most object detectors handle more than one class, so a multi-class focal loss function would cover more use cases than the existing binary focal loss released in v0.8.0.
Additionally, there are many different implementations of multi-class focal loss floating around on the web (PyTorch forums, GitHub, etc.). As the authors of the RetinaNet paper, Facebook AI Research should provide a definitive version to settle any existing debates.
Pitch
To the best of my understanding, this version by Thomas V. in the PyTorch forums seems correct. Please feel free to correct me if this is not the right approach.
Alternatives
Individual practitioners continue writing their own implementations.
Additional context
The RetinaNet paper doesn't provide any equations to describe multi-class focal loss, so I think that's partially why people currently have varying implementations. In particular, `alpha_t` is not defined, so I noticed that Thomas and other users don't follow the same `alpha --> alpha_t` conversion used in torchvision's current binary focal loss implementation.
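For context, the `alpha --> alpha_t` conversion in torchvision's binary implementation is essentially the following (paraphrased from `torchvision.ops.sigmoid_focal_loss`); a softmax-based multi-class version has no single agreed analogue, which is part of why implementations diverge:

```python
# With binary targets t in {0, 1}: alpha weights positives, (1 - alpha) negatives
alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
loss = alpha_t * ((1 - p_t) ** gamma) * ce_loss
```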