
Multi-class focal loss #3250

Open · Tracked by #2980
addisonklinke opened this issue Jan 13, 2021 · 20 comments

@addisonklinke

🚀 Feature

Define an official multi-class focal loss function

Motivation

Most object detectors handle more than one class, so a multi-class focal loss function would cover more use cases than the existing binary focal loss released in v0.8.0.

Additionally, there are many different implementations of multi-class focal loss floating around on the web (PyTorch forums, GitHub, etc.). As the authors of the RetinaNet paper, Facebook AI Research should provide a definitive version to settle any existing debates.

Pitch

To the best of my understanding, this version by Thomas V. in the PyTorch forums seems correct. Please feel free to correct me if this is not the right approach.

import torch
import torch.nn.functional as F

batch_size = 8
num_classes = 5
logits = torch.randn(batch_size, num_classes)           # raw model outputs
targets = torch.randint(0, num_classes, (batch_size,))  # integer class labels

alpha = 0.25
gamma = 2
# cross_entropy = log_softmax + nll_loss; reduction='none' keeps per-sample values
ce_loss = F.cross_entropy(logits, targets, reduction='none')
pt = torch.exp(-ce_loss)  # softmax probability of the true class
focal_loss = alpha * (1 - pt) ** gamma * ce_loss

Alternatives

Individual practitioners continue writing their own

Additional context

The RetinaNet paper doesn't provide any equations describing the multi-class focal loss, so I think that's partially why people currently have varying implementations. In particular, alpha_t is not defined, so I noticed Thomas and other users don't follow the same alpha --> alpha_t conversion used in torchvision's current binary focal loss implementation.
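For comparison, here is a minimal sketch of how the binary alpha --> alpha_t convention could be carried over to the multi-class case, with alpha as a per-class weight vector indexed by the ground-truth class. This weighting scheme is my own assumption for illustration; the paper never defines alpha_t beyond the binary case.

import torch
import torch.nn.functional as F

batch_size, num_classes = 8, 5
logits = torch.randn(batch_size, num_classes)
targets = torch.randint(0, num_classes, (batch_size,))
gamma = 2

ce_loss = F.cross_entropy(logits, targets, reduction='none')
pt = torch.exp(-ce_loss)
# Hypothetical multi-class reading of alpha_t: one weight per class,
# gathered by ground-truth label (all 0.25 here, but any per-class
# weights could be used to counter class imbalance)
alpha = torch.full((num_classes,), 0.25)
alpha_t = alpha[targets]  # shape: (batch_size,)
focal_loss = alpha_t * (1 - pt) ** gamma * ce_loss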

@oke-aditya
Contributor

Hi, the discussion for adding loss functions to torchvision is here: #2980.
I think this also sounds like a use case particular to computer vision, rather than a very general loss function like nn.CrossEntropyLoss.

@addisonklinke
Author

@oke-aditya Thank you for adding my request to your task list. Is there an estimated timeline for completing #2980?

Also in the meantime, do you have any feedback on the proposed code snippet here - is the implementation correct?

@oke-aditya
Contributor

I don't have an estimated timeline for the loss functions, since they are not currently on the torchvision roadmap (#3221).
So we might have to wait for such an API.
I will have a look at the above snippet and post a version that seems fine to me here.

@addisonklinke
Author

Okay, understood. Excellent! I look forward to your feedback on the implementation, and I'm sure others looking for a reference will find it helpful as well!

@Nuno-Mota

I have been using pretty much the same approach for computing a multi-class focal loss, but I was encountering some numerical instabilities.
I did not have as much time to investigate the issue as I would have liked, but it seems to happen when the model is highly confident (estimated probability approx. 1) and 0 < gamma < 1.
As a quick fix, I ended up clamping 1 - pt (let's name it clamped_one_minus_pt) and clamped_one_minus_pt ** gamma between 0 + epsilon and 1 - epsilon, which seems to work.
Maybe I simply messed something up on my end, but I just wanted to point this out, as it might warrant further investigation.
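A minimal sketch of that quick fix on top of the earlier snippet (the epsilon value and variable names here are my own choices):

import torch
import torch.nn.functional as F

batch_size, num_classes = 8, 5
logits = torch.randn(batch_size, num_classes)
targets = torch.randint(0, num_classes, (batch_size,))

alpha, gamma, eps = 0.25, 0.5, 1e-6  # 0 < gamma < 1 is where the instability appeared
ce_loss = F.cross_entropy(logits, targets, reduction='none')
pt = torch.exp(-ce_loss)
# Clamp both the base and the powered term away from 0 and 1; the gradient
# of x ** gamma blows up at x = 0 when gamma < 1
clamped_one_minus_pt = (1 - pt).clamp(min=eps, max=1 - eps)
focal_term = (clamped_one_minus_pt ** gamma).clamp(min=eps, max=1 - eps)
focal_loss = alpha * focal_term * ce_loss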

@hgaiser
Contributor

hgaiser commented Jan 15, 2021

> [quotes the pitch from the original post, including the snippet that computes ce_loss with F.cross_entropy]

I think the use of cross_entropy is wrong, or at the very least not what the authors intended. "cross_entropy combines log_softmax and nll_loss in a single function", but the RetinaNet paper clearly says they used sigmoid in the loss function.

[screenshot of the relevant passage from the RetinaNet paper]

In addition, I was in contact with one of the authors at one point (Tsung-Yi Lin), and he confirmed they are not using softmax, but that it was more a matter of preference.

Could you explain the use of the term "multi-class" in the name for me? The focal loss function in torchvision is used on datasets like COCO, which contain multiple classes. Why is that not a multi-class focal loss?

@addisonklinke
Author

addisonklinke commented Jan 15, 2021

@hgaiser I think softmax (provided by nn.CrossEntropyLoss) is correct, since it's the equivalent of sigmoid for multiple classes. I believe the paper mentions sigmoid only because they

> introduce the focal loss starting from the cross entropy (CE) loss for binary classification (section 3)

A dataset like COCO would be my intended use case, but I don't see how the existing torchvision implementation can handle this. The docstring for targets says

> Stores the binary classification label for each element in inputs (0 for the negative class and 1 for the positive class)

If I try to predict multiple class labels (below) with a detection model, the existing binary loss function cannot be used. There is no way to encode [0, 1, 2, ..., N] class labels into the binary [0, 1] target tensor specified in the docs. Unless, of course, we calculate it repeatedly for each class vs. the background label. Does this help clarify?

0 = background
1 = car
2 = horse
3 = tree
...
N = skateboard

@oke-aditya
Contributor

oke-aditya commented Jan 15, 2021

I think the focal loss in torchvision is a binary focal loss, and it is used to train multi-class models as follows:
while training, we treat each class of a given label as the foreground and every other class as the background. As @addisonklinke mentioned, we repeat this for each class (see the sketch at the end of this comment).

Here is the code that does it.

Here are implementations of multi-class focal loss, with and without logits, as well as binary focal loss, with and without logits.

What I understand is that the current focal loss works for the multi-class case by the above logic; this is multi-class object detection, no doubt.

A truly multi-class focal loss would directly calculate the loss value from targets 1..N against the ground-truth class. We would not need to binarize them.

But will this method work for detection models? I mean, is there any proof of concept of how a multi-class focal loss is used in detection tasks?
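For concreteness, a sketch of the binarization being discussed. Treating each foreground class one-vs-rest amounts to one-hot encoding the integer labels, with background as the all-zeros row (the class count and labels here are made up for illustration):

import torch
import torch.nn.functional as F

num_classes = 4                      # 0 = background, 1..3 = foreground classes
labels = torch.tensor([0, 2, 3, 1])  # integer label per proposal

# One-vs-rest binarization: row i is the binary target vector for proposal i;
# dropping the background column encodes background as all zeros
binary_targets = F.one_hot(labels, num_classes)[:, 1:].float()
# tensor([[0., 0., 0.],   <- background
#         [0., 1., 0.],
#         [0., 0., 1.],
#         [1., 0., 0.]])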

@addisonklinke
Author

@oke-aditya Interesting. Perhaps my definition of multi-class focal loss is off then, since it seems like we're talking about the same goal (training an object detector with multiple foreground classes).

In that case, looping through each foreground class seems rather inefficient when we could compute them all at once with a single softmax + cross entropy, as Thomas and @Nuno-Mota suggest. The RetinaNet loop you linked also requires converting the multi-class softmax predicted by the model into a binary representation. With my proposed code, we could use the model prediction directly, without having to translate it into another format.

Do those advantages clarify why I'd like a native multi-class loss function? If my proposed code is invalid, perhaps the existing loss function could handle the foreground-class looping internally. Then we could pass model class predictions directly, and there would be less room for user error.

@hgaiser
Contributor

hgaiser commented Jan 15, 2021

Ah okay, I think I see some misunderstandings. I'll try to explain some things; hopefully that clears things up a bit. By the way, I am the author of the RetinaNet PR for torchvision, and I have largely implemented keras-retinanet as well.

  1. Whether you're using softmax or sigmoid, object detection networks (as far as I've seen them) output a vector for each object proposal, predicting its class. The annotation for this is generally a one-hot vector, where at most one of the values can be 1.
  2. Because of how it is designed, computing the softmax of a prediction vector ensures that the output vector sums to 1. How would you then encode a vector that describes background? The current method is to add a "background" class, so that the value for that class is expected to be 1.
  3. Sigmoid works differently, making each value in the vector independent of the others. Because the values are independent, the sum does not need to be 1. Therefore it is easy to encode background as a vector of all zeros. Each value in the vector can then relate to a class from the dataset, without needing to add a "background" class (see the sketch after this post).
  4. The focal loss function in torchvision does indeed use binary cross entropy, but this doesn't mean that it only supports two-class classification. Let's zoom in on a specific classification, where we have a one-hot annotation vector and a same-size prediction vector. Binary cross entropy looks at each pair of values in these vectors and treats it as a classification: the annotation vector says a value should be 0, but the prediction vector has it predicted as 0.75, so the loss for that classification is computed. This is repeated for each class in the vector.
  5. This does not mean that there is a loop iterating over each class. The explanation in point 4 is a simplification; in reality, matrix operations compute all of these simultaneously. The for loop that was referenced here loops over images in a mini-batch and computes the loss for each image. Perhaps this can be optimized further to avoid a loop, but that is outside the topic of discussion here, since it does not depend on softmax / sigmoid.
  6. Basically, it comes down to preference. You can use either sigmoid or softmax to create a network that classifies COCO classes; both will work well. Whether one works better than the other, I don't know for sure. As I said before, the authors of the focal loss paper used sigmoid, so we wrote the implementation using sigmoid too.
  7. I used to use softmax, because it was used in FasterRCNN and derivatives. After implementing keras-retinanet and implementing focal loss with sigmoid, I now prefer sigmoid. My motivation is that: 1) it prevents an unnecessary background class; 2) it allows classifying "multi-label" targets (not discussed in this post, but softmax does not allow multi-label); 3) it provides more information in the output. A softmax output of [0.5, 0.5] could mean that the network is equally certain about both classes, but using sigmoid this could be [0.25, 0.25] or [0.75, 0.75]. To me these mean two different things: in the first, the network is equally uncertain; in the second, the network is equally certain. This is not possible with softmax either, because the vector must sum up to 1.

This is a long post, so I will keep it at this. I hope this helped clear things up a bit 👍. Let me know if you have any questions.
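A tiny sketch of points 2, 3, and 7 (the logit values are made up for illustration):

import torch

logits = torch.tensor([[-4.0, -4.0, -4.0],   # a background-ish proposal
                       [ 3.0, -3.0, -3.0]])  # a confident "class 0" proposal

# Softmax forces each row to sum to 1, so "all background" cannot be
# expressed without a dedicated background class:
print(torch.softmax(logits, dim=1).sum(dim=1))  # tensor([1., 1.])

# Sigmoid treats every class independently, so background is simply
# "all values near 0":
print(torch.sigmoid(logits))
# row 0 -> all approx. 0.02 (background), row 1 -> approx. [0.95, 0.05, 0.05]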

@addisonklinke
Author

@hgaiser Thank you for taking the time to clarify with a longer explanation - that helps a lot! Previously, I thought softmax + a dedicated background class was the accepted approach in object detection, so I didn't realize there was an option to use sigmoid + one-hot encodings.

Would this be an appropriate use of the existing loss function according to your instructions?

import torch
import torch.nn.functional as F
from torchvision.ops import sigmoid_focal_loss

batch_size = 8
num_classes = 3

# Randomly assign the GT class for each sample in the batch,
# then one-hot encode it as the binary target tensor
class_idx = torch.randint(0, num_classes, (batch_size,))
gt = F.one_hot(class_idx, num_classes).float()

# Get fake predictions from our "model" and compare; note that the default
# reduction='none' returns a (batch_size, num_classes) tensor of losses
logits = torch.randn(batch_size, num_classes)
loss = sigmoid_focal_loss(logits, gt)

@hgaiser
Contributor

hgaiser commented Jan 15, 2021

@addisonklinke yes, I would think that is correct.

@hgaiser
Contributor

hgaiser commented Jan 20, 2021

@addisonklinke just curious, have you had a chance to test this?

@addisonklinke
Author

Unfortunately, I haven't yet. This project got pushed to the back burner, but I am hoping to get back to it in the next couple of months.

@captainst

captainst commented Nov 19, 2021

I think that for a multi-class classification task, this implementation looks like the correct one.
Practically, for a multi-class (categorical) classification task, the focal loss should address the inter-class imbalance problem. For example, given the 3 classes [man, woman, cat], the model is expected to be confident about "cat" while having some difficulty distinguishing between man and woman. Suppose the GT is [0, 0, 1] and the prediction is [0.1, 0.1, 0.8]; then the result is already quite good, and we don't need to penalize it further. But if, say, the GT is [1, 0, 0] and the prediction is [0.5, 0.4, 0.1], then we need to penalize on the first two items. The way to do it is to use (1 - softmax(pred)) ** gamma as the focal term. That's why the implementation first calculates the softmax and then the NLLLoss: the softmax is reused to calculate the focal term.
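A minimal sketch of that pattern, computing the softmax once and reusing it for both the NLL term and the focal term (variable names are mine):

import torch
import torch.nn.functional as F

batch_size, num_classes = 8, 3
logits = torch.randn(batch_size, num_classes)
targets = torch.randint(0, num_classes, (batch_size,))
gamma = 2.0

log_probs = F.log_softmax(logits, dim=1)  # log of softmax, computed once
probs = log_probs.exp()                   # reuse: softmax probabilities
nll = F.nll_loss(log_probs, targets, reduction='none')
pt = probs.gather(1, targets.unsqueeze(1)).squeeze(1)  # prob of the true class
focal_loss = (1 - pt) ** gamma * nll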

@CharlesGaydon

CharlesGaydon commented Jan 19, 2022

@hgaiser Thanks a lot for this explanation, it really clarifies everything.

I think one additional thing to be aware of when sigmoids are used in the loss is that this should be reflected during inference. It would not be necessary if we just predicted the class with the max logit, but it is needed in situations where we want to exploit the probability values, e.g. to abstain from deciding below a probability threshold.

In a nutshell: if model calibration is important, we must reuse the sigmoid to infer probabilities whenever sigmoids were used in the focal loss.
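For example, a minimal sketch (the names and threshold are made up):

import torch

logits = torch.randn(4, 3)              # per-class logits at inference time
probs = torch.sigmoid(logits)           # matches a sigmoid-based focal loss
# probs = torch.softmax(logits, dim=1)  # would only match a softmax-based loss
confident = probs.max(dim=1).values > 0.5  # abstain below the threshold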

@hgaiser
Contributor

hgaiser commented Jan 19, 2022

> In a nutshell: if model calibration is important, we must reuse the sigmoid to infer probabilities whenever sigmoids were used in the focal loss.

Hmm, I'm not sure I'm following. If you use a sigmoid, you can still interpret the classification as a confidence probability (as its values are between 0 and 1), just not w.r.t. the other classes.

@CharlesGaydon

> Hmm, I'm not sure I'm following. If you use a sigmoid, you can still interpret the classification as a confidence probability (as its values are between 0 and 1), just not w.r.t. the other classes.

Oh yes, of course. I just meant that whatever is used in the loss calculation must also be used to get the probabilities.

@hgaiser
Contributor

hgaiser commented Jan 19, 2022

Ah, like that. Yeah, then I absolutely agree :).

@pytholic

> [quotes @hgaiser's seven-point explanation above in full]

Thank you for the explanation @hgaiser.

I am using a UNet PyTorch implementation for 3D segmentation, following this example.

I was wondering: do I need to input raw logits to sigmoid_focal_loss, or should I input probabilities after applying a softmax/sigmoid activation? Here is the code snippet:

import time

import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm

# UNet, Action, prepare_batch, device, dice_loss, NUM_CLASSES, and
# CHANNELS_DIMENSION come from the linked example
def get_model_and_optimizer(device):
    model = UNet(
        in_channels=1,
        out_classes=NUM_CLASSES,
        dimensions=3,
        num_encoding_blocks=3,
        out_channels_first_layer=8,
        normalization='batch',
        upsampling_type='linear',
        padding=True,
        activation='PReLU',
    ).to(device)
    optimizer = torch.optim.AdamW(model.parameters())
    return model, optimizer

def run_epoch(epoch_idx, action, loader, model, optimizer):
    is_training = action == Action.TRAIN
    epoch_losses = []
    times = []
    model.train(is_training)
    for batch_idx, batch in enumerate(tqdm(loader)):
        inputs, targets = prepare_batch(batch, device)
        optimizer.zero_grad()
        with torch.set_grad_enabled(is_training):
            logits = model(inputs)
            probabilities = F.softmax(logits, dim=CHANNELS_DIMENSION)
            batch_losses = dice_loss(probabilities, targets)
            # batch_losses = ce_loss(probabilities, targets)
            batch_loss = batch_losses.mean()
            if is_training:
                batch_loss.backward()
                optimizer.step()
            times.append(time.time())
            epoch_losses.append(batch_loss.item())
    epoch_losses = np.array(epoch_losses)
    print(f"{action.value} mean loss: {epoch_losses.mean():0.3f}")
    return times, epoch_losses
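For what it's worth, torchvision's sigmoid_focal_loss applies torch.sigmoid to its inputs internally, so it expects raw logits rather than probabilities. A minimal sketch with shapes loosely matching a 3D segmentation batch (the shapes and the mean reduction are my assumptions, not part of the example above):

import torch
import torch.nn.functional as F
from torchvision.ops import sigmoid_focal_loss

logits = torch.randn(2, 3, 16, 16, 16)         # (N, C, D, H, W) raw logits
labels = torch.randint(0, 3, (2, 16, 16, 16))  # integer class map
# One-hot encode and move the class axis to the channel dimension
targets = F.one_hot(labels, num_classes=3).permute(0, 4, 1, 2, 3).float()
# Pass raw logits directly; do not apply softmax/sigmoid first
loss = sigmoid_focal_loss(logits, targets, reduction='mean')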

dhruvbird added a commit to dhruvbird/pytorch-vision that referenced this issue Jun 18, 2023
In image segmentation tasks, focal loss is useful when trying to classify an image pixel as one of N classes. Unfortunately, `sigmoid_focal_loss()` isn't useful in such cases. I found that others have been asking for this as well here (pytorch#3250), so I decided to submit a PR for the same.
dhruvbird added a commit to dhruvbird/pytorch-vision that referenced this issue Jul 25, 2023 (same message)
dhruvbird added a commit to dhruvbird/pytorch-vision that referenced this issue Jul 26, 2023 (same message)