Kd_loss avg over tokens #1885

Merged
merged 2 commits into from
Oct 23, 2024

moussaKam
Contributor

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Please link to any issues this PR addresses.

Changelog

What are the changes made in this PR?
Follow-up to issue #1865.
In ForwardKLWithChunkedOutputLoss, the loss is now averaged over non-masked tokens instead of over chunks.
In this notebook I show an example.
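
To make the difference concrete, here is a minimal sketch in plain Python (not the notebook's code). The per-token loss values and the chunk layout are made up for illustration, and the "old" behaviour is modelled as I read the description above: normalize inside each chunk, then average the chunk results.

```python
# Hypothetical per-token losses for 2 chunks; None marks a masked (ignored) token.
chunks = [
    [3.0, None, None, None],  # 1 non-masked token
    [1.0, 1.0, 1.0, None],    # 3 non-masked tokens
]


def chunk_mean(chunk):
    """Mean over the non-masked tokens of a single chunk (0.0 if all are masked)."""
    kept = [v for v in chunk if v is not None]
    return sum(kept) / len(kept) if kept else 0.0


# Averaging over chunks: each chunk counts equally, however few tokens it holds.
avg_over_chunks = sum(chunk_mean(c) for c in chunks) / len(chunks)

# Averaging over tokens: each non-masked token counts equally.
kept_all = [v for c in chunks for v in c if v is not None]
avg_over_tokens = sum(kept_all) / len(kept_all)

print(avg_over_chunks)  # 2.0
print(avg_over_tokens)  # 1.5
```

The two averages only agree when every chunk holds the same number of non-masked tokens, which is generally not the case once padding and prompt masking are involved.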

I ran some experiments to compare the old and the new loss, following the tutorial here.

Here are the two training commands:

CUDA_VISIBLE_DEVICES=7 tune run knowledge_distillation_single_device --config llama3_2/knowledge_distillation_single_device kd_loss._component_=torchtune.modules.loss.OldForwardKLWithChunkedOutputLoss batch_size=16 gradient_accumulation_steps=8 enable_activation_checkpointing=False metric_logger.log_dir=/tmp/kd_output_old profiler.output_dir=/tmp/kd_output/profiling_outputs_old checkpointer.output_dir=/tmp/Llama-3.2-1B-Instruct_old/
CUDA_VISIBLE_DEVICES=6 tune run knowledge_distillation_single_device --config llama3_2/knowledge_distillation_single_device kd_loss._component_=torchtune.modules.loss.FixedForwardKLWithChunkedOutputLoss batch_size=16 gradient_accumulation_steps=8 enable_activation_checkpointing=False metric_logger.log_dir=/tmp/kd_output_fixed profiler.output_dir=/tmp/kd_output/profiling_outputs_fixed checkpointer.output_dir=/tmp/Llama-3.2-1B-Instruct_fixed/

The results were as follows:

| Dataset | Metric | Old Loss | Fixed Loss |
| --- | --- | --- | --- |
| TruthfulQA | mc1 | 0.2766 | 0.2778 |
| TruthfulQA | mc2 | 0.4348 | 0.4381 |
| HellaSwag | acc | 0.4549 | 0.4542 |
| HellaSwag | acc_norm | 0.6103 | 0.6113 |
| CommonSenseQA | acc | 0.5553 | 0.5594 |

Please let me know if further work should be done.

Test plan

Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

UX

If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example
and a tutorial example

  • I did not change any public API
  • I have added an example to docs or docstrings


pytorch-bot bot commented Oct 23, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1885

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEVs

There are 1 currently active SEVs. If your PR is affected, please view them below:

✅ No Failures

As of commit aee21d4 with merge base 1e5f0d5:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Oct 23, 2024
@joecummings
Contributor

Thanks for the PR @moussaKam - I'll take a look later today, and I've tagged @lindawangg, who implemented the first version of KD.

Contributor

@joecummings joecummings left a comment

One design consideration, but looks good.

@@ -78,7 +78,23 @@ def __init__(self, num_output_chunks: int = 8, ignore_index: int = -100):
        super().__init__()
        self.num_output_chunks = num_output_chunks
        self.ignore_index = ignore_index
        self.fkl_loss = ForwardKLLoss(ignore_index)

    def non_chunked_forward_kl_loss(
Contributor

Why did you opt to re-create the ForwardKLLoss within the function instead of as another class that just returns the unnormalized calculation? I see a couple of options:

  1. Re-use the ForwardKLLoss method but add a param called normalize, which controls whether the loss is normalized by the number of non-masked tokens
  2. Create a new class called ForwardKLLossUnnormalized that returns the loss directly to be used in this implementation

Both options assume we need to keep the original ForwardKLLoss. I don't know enough to say that we don't, but figured I'd ask whether we want to keep it or can just delete it.

@moussaKam
Contributor Author

@joecummings you're right, it's better to reuse the existing class by adding a normalize param. Let me know if the modifications are good.

As for ForwardKLLoss, I think we should keep it because it's faster if enough memory is available.
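
The design agreed on here (a normalize flag on the base loss, with the chunked loss summing unnormalized chunk results and dividing once by the total non-masked count) can be sketched in plain Python. This is an illustration under stated assumptions, not torchtune's implementation: `forward_kl`, `chunked_forward_kl`, and `log_softmax` are hypothetical names, inputs are flattened lists of per-token logits rather than tensors, and the handling of infinite logits in the real code is omitted.

```python
import math

IGNORE_INDEX = -100  # same default value as ignore_index in the PR


def log_softmax(logits):
    """Numerically stable log-softmax over one token's logits."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - lse for v in logits]


def forward_kl(student_logits, teacher_logits, labels, normalize=True):
    """Sum of -teacher_prob * student_logprob over non-masked tokens.

    With normalize=True the sum is divided by the number of non-masked tokens.
    """
    total, count = 0.0, 0
    for s, t, y in zip(student_logits, teacher_logits, labels):
        if y == IGNORE_INDEX:
            continue  # masked tokens contribute nothing
        teacher_prob = [math.exp(v) for v in log_softmax(t)]
        student_logprob = log_softmax(s)
        total -= sum(p * lp for p, lp in zip(teacher_prob, student_logprob))
        count += 1
    if not normalize:
        return total
    return total / count if count else 0.0


def chunked_forward_kl(student_logits, teacher_logits, labels, num_chunks=2):
    """Accumulate unnormalized chunk losses, then normalize once over all tokens."""
    size = max(1, math.ceil(len(labels) / num_chunks))
    total = 0.0
    for i in range(0, len(labels), size):
        total += forward_kl(
            student_logits[i:i + size],
            teacher_logits[i:i + size],
            labels[i:i + size],
            normalize=False,
        )
    count = sum(1 for y in labels if y != IGNORE_INDEX)
    return total / count if count else 0.0


# Made-up logits for 4 tokens over a vocabulary of 3; the third token is masked.
student = [[0.1, 0.2, 0.7], [1.0, 0.0, -1.0], [0.3, 0.3, 0.3], [2.0, 1.0, 0.0]]
teacher = [[0.0, 0.5, 1.0], [0.5, 0.5, 0.0], [1.0, 2.0, 3.0], [0.0, 0.0, 1.0]]
labels = [2, 1, IGNORE_INDEX, 0]

full = forward_kl(student, teacher, labels)
chunked = chunked_forward_kl(student, teacher, labels)
print(math.isclose(full, chunked))
```

Because normalization happens only once, after all chunks are summed, the chunked result matches the non-chunked one even though the two chunks hold different numbers of non-masked tokens.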

Contributor

@joecummings joecummings left a comment

This LGTM, but I'll also let @lindawangg chime in!

Double checking that the numbers you reported did not change after you made these changes?

@moussaKam
Contributor Author

Yeah, I double-checked the values and they are good. I updated the Colab notebook with the changes, but it looks like the tests need to be updated.

@@ -50,6 +52,8 @@ def forward(
        prod_probs = torch.masked_fill(teacher_prob * student_logprob, inf_mask, 0)
        x = torch.sum(prod_probs, dim=-1).view(-1)
        mask = (labels != self.ignore_index).int()
        if not normalize:
            return -torch.sum(x * mask.view(-1), dim=0)
        if torch.sum(mask.view(-1), dim=0) == 0:
            return torch.tensor(0.0, device=x.device)
        return -torch.sum(x * mask.view(-1), dim=0) / torch.sum(mask.view(-1), dim=0)
Collaborator

FWIW we have masked_mean which we use in other parts of the codebase for this

def masked_mean(
but not hugely fussed about whether it's used here.
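
For readers unfamiliar with the idea, a masked mean can be sketched as below. This is a hypothetical plain-Python stand-in and not torchtune's masked_mean, whose signature is truncated in the snippet above; the zero-count guard mirrors the check in the diff and is an assumption about the desired behaviour.

```python
def masked_mean(values, mask):
    """Mean of `values` over positions where `mask` is 1 (0.0 if none are kept)."""
    kept = sum(mask)
    if kept == 0:
        return 0.0
    return sum(v * m for v, m in zip(values, mask)) / kept


per_token = [0.5, 2.0, 1.5, 4.0]  # made-up per-token losses
mask = [1, 0, 1, 0]               # 1 = keep, 0 = ignored label
print(masked_mean(per_token, mask))  # 1.0
```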

Contributor

@lindawangg lindawangg left a comment

Change makes sense to me. Thanks!

@joecummings joecummings merged commit b02825a into pytorch:main Oct 23, 2024
17 checks passed
@moussaKam moussaKam deleted the kd_loss_avg_tokens branch October 24, 2024 08:07
@xuwenxinedu

I'm confused about the annotation.
Line 41: labels (torch.Tensor): Ground truth labels of shape (batch_size, vocab_size).
Line 101: labels (torch.Tensor): Ground truth labels of shape (batch_size, num_tokens).
Why are they different?

@moussaKam
Contributor Author

Hi @xuwenxinedu, that looks like a typo; it should be num_tokens in both cases.
