
Fix LigerCrossEntropyLoss Reduction Behavior for "None" Mode #435

Merged 4 commits into linkedin:main on Dec 10, 2024

Conversation

@hebiao064 (Collaborator) commented on Dec 8, 2024

Summary

Closes #421

This pull request addresses an issue in the cross_entropy_forward function where the reduction="none" mode did not behave as expected.

Previously, the function always returned a single scalar value, even when reduction="none" was specified. This update ensures that when reduction="none" is used, the function directly outputs the unreduced loss array (loss_1d) instead of summing it.

Changes Made:

  • Added a condition to handle reduction="none", ensuring the function outputs loss_1d directly (see the sketch below).
  • Updated the computation of z_loss to respect the reduction="none" mode.
  • Added a test for the reduction="none" case.
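
A minimal sketch of the branching described above, assuming a standalone helper; the function name and exact structure are illustrative, not the actual Liger-Kernel code:

```python
import torch

def _finalize_loss(loss_1d: torch.Tensor, reduction: str) -> torch.Tensor:
    # Hypothetical helper mirroring the described behavior:
    # with reduction="none" the per-example losses are returned unreduced;
    # otherwise they are summed (any mean scaling is handled elsewhere).
    if reduction == "none":
        return loss_1d
    return torch.sum(loss_1d)
```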

Why do we pass a gradient to output.backward()?

Background on Gradients in PyTorch

  • Scalar Outputs: When a tensor is a scalar (a single number), PyTorch can compute gradients automatically by assuming the scalar has an implicit gradient of 1.0.
  • Non-Scalar Outputs: For tensors that are not scalars, gradients must be provided explicitly because PyTorch cannot infer the shape or distribution of gradients. Without this, it raises the error: "grad can be implicitly created only for scalar outputs."
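
A quick self-contained illustration of the difference, using plain PyTorch (independent of Liger-Kernel):

```python
import torch

x = torch.randn(4, requires_grad=True)

# Scalar output: backward() needs no argument; the implicit gradient is 1.0.
x.sum().backward()

# Non-scalar output: calling y.backward() with no argument would raise
# "grad can be implicitly created only for scalar outputs", so an explicit
# gradient tensor of the same shape must be supplied.
y = x * 2
y.backward(gradient=torch.ones_like(y))
```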

Why reduction="none" Needs Explicit Gradients

When reduction="none", the loss function does not reduce the per-example loss values into a single scalar. Instead, it outputs a vector of losses, with one value per example in the batch. This means that the loss tensor has multiple values, and PyTorch cannot assume what the gradient for each of these values should be unless explicitly provided.

The Fix

By passing gradient=torch.ones_like(loss) to backward():

  • Gradient Tensor: The torch.ones_like(loss) serves as the gradient tensor. It specifies that each element in the loss tensor contributes equally to the gradients during backpropagation.
  • Shape Match: The gradient tensor's shape matches the loss tensor's shape, fulfilling PyTorch's requirements for non-scalar outputs during backward().
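
For illustration, here is the same call pattern with PyTorch's built-in cross entropy; the test added in this PR exercises LigerCrossEntropyLoss the same way (shapes here are arbitrary):

```python
import torch
import torch.nn.functional as F

logits = torch.randn(8, 10, requires_grad=True)
target = torch.randint(0, 10, (8,))

# reduction="none" yields one loss per example (shape (8,)), so backward()
# needs an explicit gradient; ones_like(loss) weights every example equally.
loss = F.cross_entropy(logits, target, reduction="none")
loss.backward(gradient=torch.ones_like(loss))
```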

Testing Done

make test

pytest /home/jobuser/Liger-Kernel/test/transformers/test_cross_entropy.py shows:

=================================== 93 passed, 1 warning in 13.18s ===================================
  • Hardware Type: NVIDIA A100-SXM4-80GB
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

@Tcc0403 requested a review from ByronHsu on December 8, 2024
@ByronHsu (Collaborator) commented:

Awesome work! Please push the branch to the main repo next time so you can run CI directly; we have disabled CI for outside forks.

@ByronHsu merged commit d790b64 into linkedin:main on Dec 10, 2024
3 of 5 checks passed
Successfully merging this pull request may close this issue: CrossEntropyLoss returns a single value when reduction is "none" (#421).