Knowledge distillation tutorial #1698
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1698.
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 4ab1789 with merge base a899da2.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
This guide will teach you about knowledge distillation (KD) and show you how you can use torchtune to distill a Llama3.1 8B model into Llama3.2 1B.
If you already know what knowledge distillation is and want to get straight to running your own distillation in torchtune,
you can jump to knowledge distillation recipe in torchtune, `knowledge_distillation_single_device.py <https://github.com/pytorch/torchtune/blob/main/recipes/knowledge_distillation_single_device.py>`_.
nit
Suggested change:
- you can jump to knowledge distillation recipe in torchtune, `knowledge_distillation_single_device.py <https://github.com/pytorch/torchtune/blob/main/recipes/knowledge_distillation_single_device.py>`_.
+ you can jump to the knowledge distillation recipe in torchtune, `knowledge_distillation_single_device.py <https://github.com/pytorch/torchtune/blob/main/recipes/knowledge_distillation_single_device.py>`_.
Actually, it might be better to just reference the latter parts of this tutorial for people who want to jump ahead rather than pointing to the recipe file,
Changed to link the tutorial section instead of recipe.
.. image:: /_static/img/kd-simplified.png

The total loss can be configured in many ways. The default KD config in torchtune combines CE loss with
Suggested change:
- The total loss can be configured in many ways. The default KD config in torchtune combines CE loss with
+ The total loss can be configured in many ways. The default KD config in torchtune combines the cross-entropy (CE) loss with
.. image:: /_static/img/kd-simplified.png

The total loss can be configured in many ways. The default KD config in torchtune combines CE loss with
forward `Kullback-Leibler (KL) divergence <https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence>`_ loss,
Suggested change:
- forward `Kullback-Leibler (KL) divergence <https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence>`_ loss,
+ the forward `Kullback-Leibler (KL) divergence <https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence>`_ loss,
The total loss can be configured in many ways. The default KD config in torchtune combines CE loss with
forward `Kullback-Leibler (KL) divergence <https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence>`_ loss,
which is used in standard KD approaches.Forward KL divergence aims to minimize the difference by forcing the student
Suggested change:
- which is used in standard KD approaches.Forward KL divergence aims to minimize the difference by forcing the student
+ which is used in standard KD approaches. Forward KL divergence aims to minimize the difference by forcing the student's
which difference are you referring to? : )
return -torch.sum(x * mask.view(-1), dim=0) / torch.sum(mask.view(-1), dim=0)

There are some details omitted to simplify the computation, but if you'd like to know more,
you can see the implementation in `ForwardKLLoss <https://github.com/pytorch/torchtune/blob/4234b78b914af23384ce0348f564e2119d107a96/torchtune/modules/loss/kd_losses.py>`_.
should be able to do
Suggested change:
- you can see the implementation in `ForwardKLLoss <https://github.com/pytorch/torchtune/blob/4234b78b914af23384ce0348f564e2119d107a96/torchtune/modules/loss/kd_losses.py>`_.
+ you can see the implementation in :class:`torchtune.modules.loss.ForwardKLLoss`.
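For readers skimming this thread, here is a rough, self-contained sketch of what the forward KL term computes, mirroring the masked return statement quoted from the tutorial above. It is an illustration only, under assumed conventions (a hypothetical forward_kl_loss helper and padded positions marked with an ignore_index of -100), not torchtune's actual ForwardKLLoss implementation:

```python
import torch
import torch.nn.functional as F

def forward_kl_loss(student_logits, teacher_logits, labels, ignore_index=-100):
    # Assumed shapes: logits [batch, seq_len, vocab_size], labels [batch, seq_len].
    # Teacher probabilities and student log-probabilities over the vocabulary.
    teacher_prob = F.softmax(teacher_logits, dim=-1)
    student_logprob = F.log_softmax(student_logits, dim=-1)

    # Teacher-weighted student log-likelihood per token. Maximizing this is
    # equivalent to minimizing forward KL(teacher || student) up to the teacher
    # entropy, which is constant with respect to the student.
    x = torch.sum(teacher_prob * student_logprob, dim=-1).reshape(-1)

    # Average only over non-padded positions, as in the quoted return statement.
    mask = (labels.reshape(-1) != ignore_index).float()
    return -torch.sum(x * mask) / torch.sum(mask).clamp(min=1.0)
```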
With torchtune, we can easily apply knowledge distillation to Llama3, as well as other LLM model families.
Let's take a look at how you could distill a model using torchtune's `KD recipe <https://github.com/pytorch/torchtune/blob/4234b78b914af23384ce0348f564e2119d107a96/recipes/knowledge_distillation_single_device.py>`_.

First, make sure that you have downloaded the Llama3 weights. For this example, we'll use the Llama3.1-8B as teacher and Llama3.2-1B as student.
nit nit:
Suggested change:
- First, make sure that you have downloaded the Llama3 weights. For this example, we'll use the Llama3.1-8B as teacher and Llama3.2-1B as student.
+ First, make sure that you have downloaded all the model weights. For this example, we'll use the Llama3.1-8B as teacher and Llama3.2-1B as student.
tune download meta-llama/Llama-3.2-1B-Instruct --output-dir /tmp/Llama-3.2-1B-Instruct --ignore-patterns "original/consolidated.00.pth" --hf_token <HF_TOKEN>

Then, we will fine-tune the teacher model with using LoRA. Based on our experiments and previous work,
nit nit (ignore if you like):
Suggested change:
- Then, we will fine-tune the teacher model with using LoRA. Based on our experiments and previous work,
+ Then, we will fine-tune the teacher model using LoRA. Based on our experiments and previous work,
and `commonsense_qa <https://github.com/EleutherAI/lm-evaluation-harness/tree/b62b9bd/lm_eval/tasks/commonsense_qa>`_
through `EleutherEval <https://github.com/EleutherAI/lm-evaluation-harness/tree/main>`_.
Suggested change:
- and `commonsense_qa <https://github.com/EleutherAI/lm-evaluation-harness/tree/b62b9bd/lm_eval/tasks/commonsense_qa>`_
- through `EleutherEval <https://github.com/EleutherAI/lm-evaluation-harness/tree/main>`_.
+ and `commonsense_qa <https://github.com/EleutherAI/lm-evaluation-harness/tree/b62b9bd/lm_eval/tasks/commonsense_qa>`_ tasks
+ through the EleutherAI `LM evaluation harness <https://github.com/EleutherAI/lm-evaluation-harness/tree/main>`_.
tune run knowledge_distillation_single_device --config llama3_2/knowledge_distillation_single_device

Ablation studies
This whole section is awesome
Hyperparameter tuning: learning rate
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

By default, the config has the learning rate as 3e-4, same as LoRA configs. For these experiments,
Suggested change:
- By default, the config has the learning rate as 3e-4, same as LoRA configs. For these experiments,
+ By default, the config has the learning rate as ``3e-4`` which is the same as the LoRA configs. For these experiments,
============================
Distilling Llama3 8B into 1B
============================
Suggested change:
- ============================
- Distilling Llama3 8B into 1B
- ============================
+ ==================================================
+ Distilling Llama3.1 8B into Llama3.2 1B using Knowledge Distillation
+ ==================================================
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

By default, the config has the learning rate as 3e-4, same as LoRA configs. For these experiments,
we changed the learning rate from as high as 1e-3 to as low as 1e-5. To change the learning rate,
mega nit: should be able to latexify this using something like :math:`1e^{-1}`
Changed all to use math.
.. code-block:: bash

    tune run knowledge_distillation_single_device --config llama3_2/knowledge_distillation_single_device optimizer.lr=[LR]
I think it's clearer to provide a more concrete (vs general) example here
Suggested change:
- tune run knowledge_distillation_single_device --config llama3_2/knowledge_distillation_single_device optimizer.lr=[LR]
+ tune run knowledge_distillation_single_device --config llama3_2/knowledge_distillation_single_device optimizer.lr=1e-3
Hyperparameter tuning: KD ratio
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

In the config, we have the kd_ratio as 0.5, which gives even weightings to both the class and KD loss. In these experiments,
Suggested change:
- In the config, we have the kd_ratio as 0.5, which gives even weightings to both the class and KD loss. In these experiments,
+ In the config, we have the ``kd_ratio`` as 0.5, which gives even weightings to both the class and KD loss. In these experiments,
.. code-block:: bash

    tune run knowledge_distillation_single_device --config llama3_2/knowledge_distillation_single_device kd_ratio=[KD_RATIO]
Same comment as above
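To make the kd_ratio weighting concrete, here is a minimal sketch of how the class (cross-entropy) loss and the KD loss could be combined. The helper names (combined_loss, kd_loss_fn) and the exact weighting convention are illustrative assumptions that follow the "even weighting at 0.5" description in the config text, not the recipe code itself:

```python
import torch
import torch.nn as nn

def combined_loss(student_logits, teacher_logits, labels, kd_loss_fn,
                  kd_ratio=0.5, ignore_index=-100):
    # Standard next-token cross-entropy ("class loss") against the ground-truth labels.
    vocab_size = student_logits.size(-1)
    ce = nn.CrossEntropyLoss(ignore_index=ignore_index)
    class_loss = ce(student_logits.reshape(-1, vocab_size), labels.reshape(-1))

    # Distillation loss against the teacher's output distribution,
    # e.g. the forward KL sketch shown earlier in this thread.
    kd_loss = kd_loss_fn(student_logits, teacher_logits, labels)

    # kd_ratio=0.5 weights both terms evenly; kd_ratio=0.0 recovers plain fine-tuning.
    return (1 - kd_ratio) * class_loss + kd_ratio * kd_loss
```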
Qwen2 1.5B to 0.5B
^^^^^^^^^^^^^^^^^^

The KD recipe can also be applied to different model families as well. Here we look at the effect of KD when the number of
mega nit:
Suggested change:
- The KD recipe can also be applied to different model families as well. Here we look at the effect of KD when the number of
+ The KD recipe can also be applied to different model families. Here we look at the effect of KD when the number of
^^^^^^^^^^^^^^^^^^

The KD recipe can also be applied to different model families as well. Here we look at the effect of KD when the number of
parameters between the teacher and student models are closer. For this experiment, we used Qwen2 1.5B and Qwen2 0.5B, which can be found in
nit:
Suggested change:
- parameters between the teacher and student models are closer. For this experiment, we used Qwen2 1.5B and Qwen2 0.5B, which can be found in
+ parameters between the teacher and student models are closer. For this experiment, we used Qwen2 1.5B and Qwen2 0.5B, the configs for which can be found in
The KD recipe can also be applied to different model families as well. Here we look at the effect of KD when the number of
parameters between the teacher and student models are closer. For this experiment, we used Qwen2 1.5B and Qwen2 0.5B, which can be found in
`qwen2/knowledge_distillation_single_device <https://github.com/pytorch/torchtune/blob/4234b78b914af23384ce0348f564e2119d107a96/recipes/configs/qwen2/knowledge_distillation_single_device.yaml>`_
config. Here we see that alpaca_cleaned_dataset only improves truthful_qa performance and drops the metrics for the other evaluation tasks.
Suggested change:
- config. Here we see that alpaca_cleaned_dataset only improves truthful_qa performance and drops the metrics for the other evaluation tasks.
+ config. Here we see that training on the alpaca cleaned dataset only improves truthful_qa performance and drops the metrics for the other evaluation tasks.
Thanks so much for adding this. I've been loosely following along with your KD work and it's really impressive. Lots of small nits but this is an awesome tutorial. I especially love all the empirical results.
return -torch.sum(x * mask.view(-1), dim=0) / torch.sum(mask.view(-1), dim=0)

There are some details omitted to simplify the computation, but if you'd like to know more,
you can see the implementation in `ForwardKLLoss <https://github.com/pytorch/torchtune/blob/4234b78b914af23384ce0348f564e2119d107a96/torchtune/modules/loss/kd_losses.py>`_.
Is it perhaps worth mentioning that we actually use ForwardKLWithChunkedOutputLoss by default?
added a note that we used the chunked loss to save memory.
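For anyone curious why the chunked variant saves memory: the idea is to avoid materializing the upcast softmax buffers for the full [batch, seq_len, vocab_size] logits at once, by processing the sequence in chunks and accumulating the masked loss. The sketch below illustrates that idea only; it is not the ForwardKLWithChunkedOutputLoss implementation, and the chunk count and helper name are made up:

```python
import torch
import torch.nn.functional as F

def chunked_forward_kl(student_logits, teacher_logits, labels,
                       num_chunks=8, ignore_index=-100):
    # Accumulate the masked numerator/denominator chunk by chunk along the
    # sequence dimension, so only one chunk's softmax buffers are alive at a time.
    num = torch.zeros((), device=student_logits.device)
    den = torch.zeros((), device=student_logits.device)
    for s, t, l in zip(student_logits.chunk(num_chunks, dim=1),
                       teacher_logits.chunk(num_chunks, dim=1),
                       labels.chunk(num_chunks, dim=1)):
        teacher_prob = F.softmax(t.float(), dim=-1)
        student_logprob = F.log_softmax(s.float(), dim=-1)
        x = torch.sum(teacher_prob * student_logprob, dim=-1).reshape(-1)
        mask = (l.reshape(-1) != ignore_index).float()
        num = num - torch.sum(x * mask)
        den = den + torch.sum(mask)
    # Same masked mean as the unchunked loss, computed without holding
    # the full-sequence probability tensors in memory at once.
    return num / den.clamp(min=1.0)
```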
What is Knowledge Distillation?
-------------------------------

`Knowledge Distillation <https://arxiv.org/pdf/1503.02531>`_ is is a widely used compression technique
Suggested change:
- `Knowledge Distillation <https://arxiv.org/pdf/1503.02531>`_ is is a widely used compression technique
+ `Knowledge Distillation <https://arxiv.org/pdf/1503.02531>`_ is a widely used compression technique
How does Knowledge Distillation work?
-------------------------------------

Knowledge is transferred from the teacher to student model by training it on a transfer set and where the
Suggested change:
- Knowledge is transferred from the teacher to student model by training it on a transfer set and where the
+ Knowledge is transferred from the teacher to student model by training it on a transfer set where the
you can see the implementation in :class:`torchtune.modules.loss.ForwardKLLoss`.
By default, the KD configs use :class:`torchtune.modules.loss.ForwardKLWithChunkedOutputLoss` to reduce memory.
I think you need to add these APIs to the rst file here. This way it'll render as an actual pointer to our API ref
added. Thanks for the pointer.
LGTM - please acquire Evan's blessing before merging.
🙌
Context
What is the purpose of this PR? Is it to
Changelog
What are the changes made in this PR?
Test plan
Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.
- pre-commit install
- pytest tests
- pytest tests -m integration_test
UX
If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example
and a tutorial example