Kd_loss avg over tokens #1885
Conversation
Dr. CI: ✅ No failures as of commit aee21d4 with merge base 1e5f0d5. See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1885.
Thanks for the PR @moussaKam - I'll take a look later today and I've tagged @lindawangg, who implemented the first version of KD.
One design consideration, but looks good.
torchtune/modules/loss/kd_losses.py (outdated)
@@ -78,7 +78,23 @@ def __init__(self, num_output_chunks: int = 8, ignore_index: int = -100):
        super().__init__()
        self.num_output_chunks = num_output_chunks
        self.ignore_index = ignore_index
        self.fkl_loss = ForwardKLLoss(ignore_index)

    def non_chunked_forward_kl_loss(
Why did you opt to re-create the `ForwardKLLoss` within the function instead of as another class that just returns the unnormalized calculation? I see a couple of options:

- Re-use the `ForwardKLLoss` method but add a param called `normalize`, which will normalize the loss by non-masked tokens
- Create a new class called `ForwardKLLossUnnormalized` that returns the loss directly to be used in this implementation

All of these options assume we have the need to keep the original `ForwardKLLoss`. I don't know enough to say that we don't, but figured I'd ask if we might want to keep it or if we can just delete.
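To make the trade-off concrete, here is a minimal sketch (not the actual torchtune code) of what the chunked caller could look like under the first option: each chunk returns its unnormalized masked sum, and the caller divides once by the total number of non-masked tokens. The function name, argument layout, and `normalize` flag placement are illustrative assumptions.

```python
import torch


def chunked_forward_kl(fkl_loss, student_chunks, teacher_chunks, labels, ignore_index=-100):
    # Hypothetical sketch: `fkl_loss` is assumed to be a ForwardKLLoss-like
    # module that accepts a `normalize` flag and returns -sum(x * mask) when
    # normalize=False.
    label_chunks = labels.chunk(len(student_chunks), dim=1)
    total_loss = 0.0
    for s_chunk, t_chunk, l_chunk in zip(student_chunks, teacher_chunks, label_chunks):
        # Accumulate the unnormalized per-chunk sums.
        total_loss = total_loss + fkl_loss(s_chunk, t_chunk, l_chunk, normalize=False)
    # Average once over all non-masked tokens instead of averaging per chunk.
    num_tokens = (labels != ignore_index).sum().clamp(min=1)
    return total_loss / num_tokens
```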
a10b124 to 1e5f0d5 (Compare)
@joecummings you're right, it's better to use the existing class by adding a `normalize` param. Let me know if the modifications are good. As for
This LGTM, but I'll also let @lindawangg chime in!
Double checking that the numbers you reported did not change after you made these changes?
Yeah, I double-checked the values, they are good. I updated the Colab notebook with the changes, but it looks like the tests should be updated.
@@ -50,6 +52,8 @@ def forward(
        prod_probs = torch.masked_fill(teacher_prob * student_logprob, inf_mask, 0)
        x = torch.sum(prod_probs, dim=-1).view(-1)
        mask = (labels != self.ignore_index).int()
        if not normalize:
            return -torch.sum(x * mask.view(-1), dim=0)
        if torch.sum(mask.view(-1), dim=0) == 0:
            return torch.tensor(0.0, device=x.device)
        return -torch.sum(x * mask.view(-1), dim=0) / torch.sum(mask.view(-1), dim=0)
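As a toy check of the two return paths in this diff (values invented purely for illustration, not taken from the PR):

```python
import torch

# x holds the per-token cross terms (already summed over the vocab), flattened;
# mask marks the non-ignored positions.
x = torch.tensor([0.5, 1.0, 2.0, 0.0])
mask = torch.tensor([1.0, 1.0, 1.0, 0.0])

unnormalized = -torch.sum(x * mask)              # normalize=False path -> -3.5
normalized = -torch.sum(x * mask) / mask.sum()   # normalize=True path  -> about -1.167
print(unnormalized.item(), normalized.item())
```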
FWIW we have `masked_mean`, which we use in other parts of the codebase for this:
torchtune/torchtune/rlhf/rewards.py, line 101 in 17ba37d: `def masked_mean(`
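For reference, a masked_mean-style helper can express the normalized branch in one line. The implementation below is only a sketch; torchtune's actual helper in `torchtune/rlhf/rewards.py` should be checked for its exact signature and any epsilon/zero-mask handling.

```python
from typing import Optional

import torch


def masked_mean(x: torch.Tensor, mask: torch.Tensor, dim: Optional[int] = None) -> torch.Tensor:
    # Mean of x over positions where mask is non-zero. No guard for an
    # all-zero mask here; the diff above handles that case explicitly.
    return (x * mask).sum(dim=dim) / mask.sum(dim=dim)


x = torch.tensor([0.5, 1.0, 2.0, 0.0])
mask = torch.tensor([1.0, 1.0, 1.0, 0.0])
loss = -masked_mean(x, mask)   # same value as -sum(x * mask) / sum(mask)
```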
Change makes sense to me. Thanks!
I'm confused about the annotation
Hi @xuwenxinedu, looks like a typo; it should be `num_tokens` in both cases.
Context
What is the purpose of this PR?
Please link to any issues this PR addresses.
Changelog
What are the changes made in this PR?
Follow-up to issue #1865.
In `ForwardKLWithChunkedOutputLoss`, the loss is averaged over non-masked tokens instead of over chunks.
In this notebook I show an example.
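Independently of the notebook, here is a rough toy illustration (invented numbers) of why averaging over tokens rather than chunks can change the value when chunks contain very different numbers of non-masked tokens:

```python
import torch

# Two chunks: one with 100 non-masked tokens, one with only 2 (e.g. mostly padding).
chunk_token_counts = torch.tensor([100.0, 2.0])
chunk_loss_sums = torch.tensor([50.0, 4.0])   # sum of per-token losses in each chunk

# Averaging per-chunk means weights the 2-token chunk as much as the 100-token chunk.
avg_over_chunks = (chunk_loss_sums / chunk_token_counts).mean()     # (0.5 + 2.0) / 2 = 1.25
avg_over_tokens = chunk_loss_sums.sum() / chunk_token_counts.sum()  # 54 / 102 ≈ 0.53
print(avg_over_chunks.item(), avg_over_tokens.item())
```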
I ran some experiments to compare the old and the new loss. I followed the tutorial here.
Here are the two commands for the training:
The results were as follows:
Please let me know if further work should be done.
Test plan
Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.
- run pre-commit hooks (`pre-commit install`)
- run unit tests (`pytest tests`)
- run recipe tests (`pytest tests -m integration_test`)
UX
If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example and a tutorial example.