Counterintuitive behavior of nn.CrossEntropy/nn.NLLLoss with weights and issue with gradient accumulation #72047
Labels
module: loss
Problem is related to loss function
module: nn
Related to torch.nn
triaged
This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
🚀 The feature, motivation and pitch
The behavior of mean reduction in nn.CrossEntropyLoss and nn.NLLLoss is counterintuitive when class weights are supplied, as discussed in #9882. The current behavior performs a weighted average (dividing by the sum of the selected per-sample weights) rather than the unweighted average over the batch size that people probably expect.
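For concreteness, here is a minimal sketch (mine, not from the issue) of what `reduction='mean'` currently computes when a `weight` tensor is supplied: the divisor is the sum of the weights selected by the targets, not the batch size.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(4, 3)              # batch of 4 samples, 3 classes
targets = torch.tensor([0, 1, 2, 1])
weight = torch.tensor([1.0, 2.0, 0.5])

weighted_mean = nn.CrossEntropyLoss(weight=weight, reduction='mean')(logits, targets)

# Reproduce it by hand from the unreduced per-sample losses (already weight-scaled).
per_sample = nn.CrossEntropyLoss(weight=weight, reduction='none')(logits, targets)
print(torch.allclose(weighted_mean, per_sample.sum() / weight[targets].sum()))  # True
print(torch.allclose(weighted_mean, per_sample.mean()))                         # False: divisor is 5.5, not 4
```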
This counterintuitive behavior also causes an issue when doing gradient accumulation. In particular, if you adjust the loss to account for gradient accumulation (i.e. to make the divisor `batch_size` x `n_grad_accum_steps` instead of just `batch_size`), you no longer get the exact gradients, i.e. the gradients you would have had if your batch size actually were `batch_size` x `n_grad_accum_steps`. You can of course address this by using `reduction='sum'` and averaging the loss manually, but this is clunky and probably frequently overlooked.
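To make the gradient accumulation point concrete, here is a sketch with a toy `nn.Linear` model and hand-picked targets (my own example, not from the issue). It compares the large-batch gradient against naive accumulation, where each chunk's mean loss is divided by `n_grad_accum_steps`, and against the `reduction='sum'` workaround.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
weight = torch.tensor([1.0, 2.0, 0.5])
x = torch.randn(8, 5)
y = torch.tensor([0, 1, 2, 1, 0, 0, 2, 1])    # per-chunk weight sums differ: 5.5 vs 4.5

def make_model():
    torch.manual_seed(1)                      # identical parameters for every run
    return nn.Linear(5, 3)

# Reference: one large batch with the current reduction='mean'.
model = make_model()
nn.CrossEntropyLoss(weight=weight)(model(x), y).backward()
ref_grad = model.weight.grad.clone()

# Naive accumulation: two chunks, each chunk's mean loss divided by n_grad_accum_steps = 2.
model = make_model()
for xc, yc in zip(x.chunk(2), y.chunk(2)):
    (nn.CrossEntropyLoss(weight=weight)(model(xc), yc) / 2).backward()
print(torch.allclose(model.weight.grad, ref_grad))   # False: each chunk divides by its own weight sum

# Workaround: reduction='sum', then divide by the total weight sum of the full batch.
model = make_model()
total_w = weight[y].sum()
for xc, yc in zip(x.chunk(2), y.chunk(2)):
    (nn.CrossEntropyLoss(weight=weight, reduction='sum')(model(xc), yc) / total_w).backward()
print(torch.allclose(model.weight.grad, ref_grad))   # True (up to floating point)
```

Note that the workaround needs the whole batch's weight sum (or, for an unweighted mean, the total sample count) up front, which is part of what makes it clunky.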
Possible solution
A straightforward solution, and the most intuitive one (at least to me), would be:

- `reduction='mean'` performs an unweighted average
- `reduction='weighted_mean'` performs a weighted average (the current behavior of `mean`)

This would also make gradient accumulation exact in the `mean` case, which is probably what users expect to happen (a sketch emulating both reductions follows below).
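For reference, both proposed reductions can be emulated today via `reduction='none'`. The sketch below is only an illustration of the proposal: `weighted_mean` is the name suggested above, not an existing PyTorch option, and the helper names are hypothetical.

```python
import torch
import torch.nn as nn

def ce_unweighted_mean(logits, targets, weight):
    # Proposed 'mean': sum of weight-scaled per-sample losses divided by the batch size.
    per_sample = nn.CrossEntropyLoss(weight=weight, reduction='none')(logits, targets)
    return per_sample.mean()

def ce_weighted_mean(logits, targets, weight):
    # Proposed 'weighted_mean': divide by the sum of the selected weights
    # (this is what reduction='mean' does today).
    per_sample = nn.CrossEntropyLoss(weight=weight, reduction='none')(logits, targets)
    return per_sample.sum() / weight[targets].sum()
```

With the first form, dividing each accumulation chunk's loss by `n_grad_accum_steps` recovers the exact large-batch gradient (assuming equal chunk sizes), because the divisor is a fixed count rather than a data-dependent weight sum.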
No response
Additional context
No response
cc @albanD @mruberry @jbschlosser @walterddr @kshitij12345