
Knowledge distillation tutorial #1698

Merged · 9 commits merged into pytorch:main on Sep 27, 2024
Conversation

@lindawangg (Contributor) commented on Sep 27, 2024

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Changelog

What are the changes made in this PR?

  • Adds a knowledge distillation tutorial on how to distill Llama3.1 8B into Llama3.2 1B

Test plan

Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)
(image attached)

UX

If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example
and a tutorial example

  • I did not change any public API
  • I have added an example to docs or docstrings

pytorch-bot bot commented Sep 27, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1698

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 4ab1789 with merge base a899da2:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label (managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) on Sep 27, 2024

This guide will teach you about knowledge distillation (KD) and show you how you can use torchtune to distill a Llama3.1 8B model into Llama3.2 1B.
If you already know what knowledge distillation is and want to get straight to running your own distillation in torchtune,
you can jump to knowledge distillation recipe in torchtune, `knowledge_distillation_single_device.py <https://github.com/pytorch/torchtune/blob/main/recipes/knowledge_distillation_single_device.py>`_.
Collaborator:

nit

Suggested change:
- you can jump to knowledge distillation recipe in torchtune, `knowledge_distillation_single_device.py <https://github.com/pytorch/torchtune/blob/main/recipes/knowledge_distillation_single_device.py>`_.
+ you can jump to the knowledge distillation recipe in torchtune, `knowledge_distillation_single_device.py <https://github.com/pytorch/torchtune/blob/main/recipes/knowledge_distillation_single_device.py>`_.

Collaborator:

Actually, it might be better to just reference the latter parts of this tutorial for people who want to jump ahead rather than pointing to the recipe file,

Contributor Author:

Changed to link the tutorial section instead of recipe.


.. image:: /_static/img/kd-simplified.png

The total loss can be configured in many ways. The default KD config in torchtune combines CE loss with
Collaborator:

Suggested change:
- The total loss can be configured in many ways. The default KD config in torchtune combines CE loss with
+ The total loss can be configured in many ways. The default KD config in torchtune combines the cross-entropy (CE) loss with

.. image:: /_static/img/kd-simplified.png

The total loss can be configured in many ways. The default KD config in torchtune combines CE loss with
forward `Kullback-Leibler (KL) divergence <https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence>`_ loss,
Collaborator:

Suggested change:
- forward `Kullback-Leibler (KL) divergence <https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence>`_ loss,
+ the forward `Kullback-Leibler (KL) divergence <https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence>`_ loss,


The total loss can be configured in many ways. The default KD config in torchtune combines CE loss with
forward `Kullback-Leibler (KL) divergence <https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence>`_ loss,
which is used in standard KD approaches.Forward KL divergence aims to minimize the difference by forcing the student
@SalmanMohammadi (Collaborator) commented on Sep 27, 2024:

Suggested change:
- which is used in standard KD approaches.Forward KL divergence aims to minimize the difference by forcing the student
+ which is used in standard KD approaches. Forward KL divergence aims to minimize the difference by forcing the student's

Collaborator:

which difference are you referring to? : )

return -torch.sum(x * mask.view(-1), dim=0) / torch.sum(mask.view(-1), dim=0)

There are some details omitted to simplify the computation, but if you'd like to know more,
you can see the implementation in `ForwardKLLoss <https://github.com/pytorch/torchtune/blob/4234b78b914af23384ce0348f564e2119d107a96/torchtune/modules/loss/kd_losses.py>`_.
Collaborator:

should be able to do

Suggested change:
- you can see the implementation in `ForwardKLLoss <https://github.com/pytorch/torchtune/blob/4234b78b914af23384ce0348f564e2119d107a96/torchtune/modules/loss/kd_losses.py>`_.
+ you can see the implementation in :class:`torchtune.modules.loss.ForwardKLLoss`.
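For context, here is a minimal, self-contained sketch of what a masked forward KL loss can look like, consistent with the return line quoted above. It is a simplified illustration (the teacher-entropy term, which is constant with respect to the student, is dropped), not the exact torchtune implementation:

import torch
import torch.nn.functional as F

def forward_kl(student_logits, teacher_logits, mask):
    # Forward KL pushes the student's distribution to cover the teacher's.
    # Shapes: logits are [batch, seq, vocab]; mask is [batch, seq] and is
    # True for tokens that should contribute to the loss.
    teacher_prob = F.softmax(teacher_logits, dim=-1)
    student_logprob = F.log_softmax(student_logits, dim=-1)
    # Per-token cross term: sum over the vocab of p_teacher * log p_student.
    x = torch.sum(teacher_prob * student_logprob, dim=-1).flatten()
    mask = mask.flatten().float()
    # Masked mean over non-padding tokens, mirroring the quoted return line.
    return -torch.sum(x * mask, dim=0) / torch.sum(mask, dim=0)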

With torchtune, we can easily apply knowledge distillation to Llama3, as well as other LLM model families.
Let's take a look at how you could distill a model using torchtune's `KD recipe <https://github.com/pytorch/torchtune/blob/4234b78b914af23384ce0348f564e2119d107a96/recipes/knowledge_distillation_single_device.py>`_.

First, make sure that you have downloaded the Llama3 weights. For this example, we'll use the Llama3.1-8B as teacher and Llama3.2-1B as student.
@SalmanMohammadi (Collaborator) commented on Sep 27, 2024:

nit nit:

Suggested change:
- First, make sure that you have downloaded the Llama3 weights. For this example, we'll use the Llama3.1-8B as teacher and Llama3.2-1B as student.
+ First, make sure that you have downloaded all the model weights. For this example, we'll use the Llama3.1-8B as teacher and Llama3.2-1B as student.


tune download meta-llama/Llama-3.2-1B-Instruct --output-dir /tmp/Llama-3.2-1B-Instruct --ignore-patterns "original/consolidated.00.pth" --hf_token <HF_TOKEN>

Then, we will fine-tune the teacher model with using LoRA. Based on our experiments and previous work,
@SalmanMohammadi (Collaborator) commented on Sep 27, 2024:

nit nit (ignore if you like):

Suggested change:
- Then, we will fine-tune the teacher model with using LoRA. Based on our experiments and previous work,
+ Then, we will fine-tune the teacher model using LoRA. Based on our experiments and previous work,

Comment on lines 120 to 121
and `commonsense_qa <https://github.com/EleutherAI/lm-evaluation-harness/tree/b62b9bd/lm_eval/tasks/commonsense_qa>`_
through `EleutherEval <https://github.com/EleutherAI/lm-evaluation-harness/tree/main>`_.
Collaborator:

Suggested change:
- and `commonsense_qa <https://github.com/EleutherAI/lm-evaluation-harness/tree/b62b9bd/lm_eval/tasks/commonsense_qa>`_
- through `EleutherEval <https://github.com/EleutherAI/lm-evaluation-harness/tree/main>`_.
+ and `commonsense_qa <https://github.com/EleutherAI/lm-evaluation-harness/tree/b62b9bd/lm_eval/tasks/commonsense_qa>`_ tasks
+ through the EleutherAI `LM evaluation harness <https://github.com/EleutherAI/lm-evaluation-harness/tree/main>`_.


tune run knowledge_distillation_single_device --config llama3_2/knowledge_distillation_single_device

Ablation studies
Collaborator:

This whole section is awesome

Hyperparameter tuning: learning rate
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

By default, the config has the learning rate as 3e-4, same as LoRA configs. For these experiments,
Collaborator:

Suggested change:
- By default, the config has the learning rate as 3e-4, same as LoRA configs. For these experiments,
+ By default, the config has the learning rate as ``3e-4`` which is the same as the LoRA configs. For these experiments,

Comment on lines 3 to 5
============================
Distilling Llama3 8B into 1B
============================
Collaborator:

Suggested change:
- ============================
- Distilling Llama3 8B into 1B
- ============================
+ ==================================================
+ Distilling Llama3.1 8B into Llama3.2 1B using Knowledge Distillation
+ ==================================================

^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

By default, the config has the learning rate as 3e-4, same as LoRA configs. For these experiments,
we changed the learning rate from as high as 1e-3 to as low as 1e-5. To change the learning rate,
Collaborator:

mega nit: should be able to latexify this using something like :math:`1e^{-1}`

Contributor Author:

Changed all to use math.
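For example (illustrative markup only; the exact RST used in the final tutorial may differ), a rate written as 3e-4 can be expressed as :math:`3\times10^{-4}` so that Sphinx renders it as math.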


.. code-block:: bash

tune run knowledge_distillation_single_device --config llama3_2/knowledge_distillation_single_device optimizer.lr=[LR]
Collaborator:

I think it's clearer to provide a more concrete (vs general) example here

Suggested change:
- tune run knowledge_distillation_single_device --config llama3_2/knowledge_distillation_single_device optimizer.lr=[LR]
+ tune run knowledge_distillation_single_device --config llama3_2/knowledge_distillation_single_device optimizer.lr=1e-3

Hyperparameter tuning: KD ratio
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

In the config, we have the kd_ratio as 0.5, which gives even weightings to both the class and KD loss. In these experiments,
Collaborator:

Suggested change:
- In the config, we have the kd_ratio as 0.5, which gives even weightings to both the class and KD loss. In these experiments,
+ In the config, we have the ``kd_ratio`` as 0.5, which gives even weightings to both the class and KD loss. In these experiments,


.. code-block:: bash

tune run knowledge_distillation_single_device --config llama3_2/knowledge_distillation_single_device kd_ratio=[KD_RATIO]
Collaborator:

Same comment as above
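For readers skimming, the "even weighting" described above corresponds to a convex combination of the two loss terms. A plausible sketch of that weighting (the function and variable names here are illustrative, not quoted from the recipe code):

def total_kd_loss(class_loss, kd_loss, kd_ratio=0.5):
    # kd_ratio=0.5 weights the hard-label (class) loss and the
    # distillation (KD) loss equally, matching the default config.
    return (1 - kd_ratio) * class_loss + kd_ratio * kd_loss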

Qwen2 1.5B to 0.5B
^^^^^^^^^^^^^^^^^^

The KD recipe can also be applied to different model families as well. Here we look at the effect of KD when the number of
Collaborator:

mega nit:

Suggested change:
- The KD recipe can also be applied to different model families as well. Here we look at the effect of KD when the number of
+ The KD recipe can also be applied to different model families. Here we look at the effect of KD when the number of

^^^^^^^^^^^^^^^^^^

The KD recipe can also be applied to different model families as well. Here we look at the effect of KD when the number of
parameters between the teacher and student models are closer. For this experiment, we used Qwen2 1.5B and Qwen2 0.5B, which can be found in
Collaborator:

nit:

Suggested change:
- parameters between the teacher and student models are closer. For this experiment, we used Qwen2 1.5B and Qwen2 0.5B, which can be found in
+ parameters between the teacher and student models are closer. For this experiment, we used Qwen2 1.5B and Qwen2 0.5B, the configs for which can be found in

The KD recipe can also be applied to different model families as well. Here we look at the effect of KD when the number of
parameters between the teacher and student models are closer. For this experiment, we used Qwen2 1.5B and Qwen2 0.5B, which can be found in
`qwen2/knowledge_distillation_single_device <https://github.com/pytorch/torchtune/blob/4234b78b914af23384ce0348f564e2119d107a96/recipes/configs/qwen2/knowledge_distillation_single_device.yaml>`_
config. Here we see that alpaca_cleaned_dataset only improves truthful_qa performance and drops the metrics for the other evaluation tasks.
Collaborator:

Suggested change:
- config. Here we see that alpaca_cleaned_dataset only improves truthful_qa performance and drops the metrics for the other evaluation tasks.
+ config. Here we see that training on the alpaca cleaned dataset only improves truthful_qa performance and drops the metrics for the other evaluation tasks.

@SalmanMohammadi (Collaborator):

Thanks so much for adding this. I've been loosely following along with your KD work and it's really impressive.

Lots of small nits but this is an awesome tutorial. I especially love all the empirical results.

return -torch.sum(x * mask.view(-1), dim=0) / torch.sum(mask.view(-1), dim=0)

There are some details omitted to simplify the computation, but if you'd like to know more,
you can see the implementation in `ForwardKLLoss <https://github.com/pytorch/torchtune/blob/4234b78b914af23384ce0348f564e2119d107a96/torchtune/modules/loss/kd_losses.py>`_.
Collaborator:

Is it perhaps worth mentioning that we actually use ForwardKLWithChunkedOutputLoss by default?

Contributor Author:

added a note that we used the chunked loss to save memory.

What is Knowledge Distillation?
-------------------------------

`Knowledge Distillation <https://arxiv.org/pdf/1503.02531>`_ is is a widely used compression technique
Contributor:

Suggested change:
- `Knowledge Distillation <https://arxiv.org/pdf/1503.02531>`_ is is a widely used compression technique
+ `Knowledge Distillation <https://arxiv.org/pdf/1503.02531>`_ is a widely used compression technique

How does Knowledge Distillation work?
-------------------------------------

Knowledge is transferred from the teacher to student model by training it on a transfer set and where the
Contributor:

Suggested change:
- Knowledge is transferred from the teacher to student model by training it on a transfer set and where the
+ Knowledge is transferred from the teacher to student model by training it on a transfer set where the
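To make the mechanics concrete, here is a minimal, hypothetical sketch of one distillation forward pass (the names are illustrative, not the recipe's actual code): the frozen teacher produces logits on the same batch, and the student is trained against both the ground-truth labels and the teacher's output distribution.

import torch

def distillation_losses(student, teacher, tokens, labels, ce_loss_fn, kd_loss_fn):
    # The teacher only supplies soft targets, so it runs without gradients.
    with torch.no_grad():
        teacher_logits = teacher(tokens)
    student_logits = student(tokens)
    # Hard-label objective: standard next-token cross-entropy.
    class_loss = ce_loss_fn(student_logits, labels)
    # Soft-label objective: match the teacher's token-level distribution.
    kd_loss = kd_loss_fn(student_logits, teacher_logits, labels)
    # The two terms are then combined with the kd_ratio weighting sketched earlier.
    return class_loss, kd_loss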

Comment on lines 81 to 82
you can see the implementation in :class:`torchtune.modules.loss.ForwardKLLoss`.
By default, the KD configs use :class:`torchtune.modules.loss.ForwardKLWithChunkedOutputLoss` to reduce memory.
Contributor:

I think you need to add these APIs to the rst file here. This way it'll render as an actual pointer to our API ref

Contributor Author:

added. Thanks for the pointer.
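For readers curious about the chunked variant referenced here (:class:`torchtune.modules.loss.ForwardKLWithChunkedOutputLoss`), the snippet below is only a rough, hypothetical illustration of why chunking lowers peak memory, not the torchtune implementation: processing the sequence a slice at a time avoids materializing the softmax over the large vocabulary for every token at once, while producing the same masked mean as the unchunked sketch earlier.

import torch
import torch.nn.functional as F

def chunked_forward_kl(student_logits, teacher_logits, mask, num_chunks=8):
    # Rough sketch of the chunking idea only.
    total = torch.zeros((), device=student_logits.device)
    count = torch.zeros((), device=student_logits.device)
    for s_chunk, t_chunk, m_chunk in zip(
        student_logits.chunk(num_chunks, dim=1),
        teacher_logits.chunk(num_chunks, dim=1),
        mask.chunk(num_chunks, dim=1),
    ):
        # Softmax/log-softmax are computed for one chunk of tokens at a time.
        teacher_prob = F.softmax(t_chunk, dim=-1)
        student_logprob = F.log_softmax(s_chunk, dim=-1)
        per_token = torch.sum(teacher_prob * student_logprob, dim=-1)
        m = m_chunk.float()
        total = total - torch.sum(per_token * m)
        count = count + m.sum()
    # Same masked mean over non-padding tokens as the unchunked version.
    return total / count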

@SalmanMohammadi (Collaborator):

LGTM - please acquire Evan's blessing before merging.

@ebsmothers (Contributor) left a comment:

🙌

@ebsmothers merged commit 7c4c629 into pytorch:main on Sep 27, 2024 · 17 checks passed