QLoRA tutorial #693
Conversation
ghstack-source-id: aa906a002fccbc9e80acfe3c4848febe23d5071f Pull Request resolved: #590
Dr. CI: ✅ No failures as of commit fa6b0a3 with merge base 5b0dc57. Artifacts and rendered test results: hud.pytorch.org/pr/pytorch/torchtune/693.
accuracy.

The `QLoRA paper <https://arxiv.org/abs/2305.14314>`_ introduces two key abstractions to decrease memory usage and avoid accuracy degradation: the bespoke 4-bit NormalFloat
type, and a double quantization method that quantizes the quantization parameters themselves to save even more memory. TorchTune uses
🤯
parameters are still held in the original precision, and activations, gradients, and optimizer states still exist in the higher precision to preserve
accuracy.

The `QLoRA paper <https://arxiv.org/abs/2305.14314>`_ introduces two key abstractions to decrease memory usage and avoid accuracy degradation: the bespoke 4-bit NormalFloat
nit: I feel like you're linking to the paper too many times, usually I just do it once upon introduction
nah, I think this is fine. you don't know which paragraph someone will start reading at if they're skimming or jumping around
I think you should link every other reference to the paper to keep people on their toes.
quantization is done through the method highlighted in the original `QLoRA paper <https://arxiv.org/abs/2305.14314>`_. Adapter
parameters are still held in the original precision, and activations, gradients, and optimizer states still exist in the higher precision to preserve
accuracy.
One thing that'd be nice: basically take the diagram in the LoRA tutorial demonstrating full finetune -> LoRA and add one more for LoRA -> QLoRA. (The diagrams take a bit of time so I feel this is more of a nice-to-have at this point.) But if you're interested let me know and I can dig up the original
Definitely interested! Will punt this out to a follow up though.
accuracy.

The `QLoRA paper <https://arxiv.org/abs/2305.14314>`_ introduces two key abstractions to decrease memory usage and avoid accuracy degradation: the bespoke 4-bit NormalFloat
type, and a double quantization method that quantizes the quantization parameters themselves to save even more memory. TorchTune uses
Is it worth giving some more details on either of these two optimizations, or do you think it's too in the weeds?
IMO it's too in the weeds and not worth re-explaining if it's in the paper. Can add a line directing folks to the paper, but IMO it's already clear enough to read the paper for this sort of detail.
in a typical LoRA training flow.

To achieve this, when using TorchTune's ``qlora_llama2_7b`` builder, we automatically register a hook, :code:`reparametrize_as_dtype_state_dict_post_hook`,
that runs after calling ``.state_dict()`` on the top level model. This hook converts ``NF4Tensors`` back to their original precision, while also offloading these
converted tensors to the CPU. This offloading is to avoid peaking memory by maintaining an entire bf16/fp32 copy of the ``state_dict``
on GPU, which could lead to potential OOMs during checkpoint save, even if memory is appropriately managed during
training.
Personally I think this whole section could be more easily illuminated via a code block with e.g. .state_dict() with peak memory printed without the state dict hook, then add the hook and do the same thing.
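For illustration, a rough sketch (not from the PR) of the kind of measurement this suggests: build the QLoRA model and print peak CUDA memory around a ``.state_dict()`` call. The builder name and the auto-registered hook follow the quoted tutorial text; the ``lora_attn_modules`` value and the measurement code itself are assumptions.

.. code-block:: python

    # Rough sketch (assumes a CUDA device). Per the quoted tutorial text, the
    # qlora_llama2_7b builder auto-registers reparametrize_as_dtype_state_dict_post_hook,
    # so the state_dict() call below converts NF4Tensors back to their original precision
    # and offloads the converted tensors to CPU.
    import torch
    from torchtune.models.llama2 import qlora_llama2_7b

    with torch.device("cuda"):
        model = qlora_llama2_7b(lora_attn_modules=["q_proj", "v_proj"])

    torch.cuda.reset_peak_memory_stats()
    sd = model.state_dict()
    peak_gib = torch.cuda.max_memory_allocated() / (1024 ** 3)
    print(f"Peak GPU memory during state_dict(): {peak_gib:.2f} GiB")

Running the same measurement with and without the hook, as the comment suggests, would make the CPU-offload benefit visible side by side.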
As well as during training:

.. code-block:: python

  Memory Stats::
    GPU peak memory allocation: 14.40 GB
    GPU peak memory reserved: 15.57 GB
    GPU peak memory active: 14.40 GB


Comparing to the memory usage during model initialization for QLoRA, we see about a 35% decrease in peak memory reserved:

.. code-block:: python

  Memory Stats after model init::
    GPU peak memory allocation: 7.36 GB
    GPU peak memory reserved: 9.13 GB
    GPU peak memory active: 7.36 GB

As well as a 40% decrease in peak memory reserved during training:

.. code-block:: python

  Memory Stats::
    GPU peak memory allocation: 5.54 GB
    GPU peak memory reserved: 9.29 GB
    GPU peak memory active: 5.54 GB
I'm a bit torn here.. while I think it's nice to print the raw output that someone will see in the logs, personally I would just do it one time, e.g. "you will see something like this" and then put it in a table (QLoRA, LoRA) x (peak init memory, peak training memory). Then it's a bit more digestible
Added table
-----------------------------------------

Putting it all together, we can now finetune a model using TorchTune's `LoRA recipe <https://github.com/pytorch/torchtune/blob/48626d19d2108f92c749411fbd5f0ff140023a25/recipes/lora_finetune.py>`_,
with a `<QLoRA configuration https://github.com/pytorch/torchtune/blob/main/recipes/configs/llama2/7B_qlora_single_device.yaml>`_.
This hyperlink format is so annoying lol
with a `<QLoRA configuration https://github.com/pytorch/torchtune/blob/main/recipes/configs/llama2/7B_qlora_single_device.yaml>`_.
with a `QLoRA configuration <https://github.com/pytorch/torchtune/blob/main/recipes/configs/llama2/7B_qlora_single_device.yaml>`_. |
.. code-block:: bash

    tune run lora_finetune_single_device --config recipes/configs/llama2/7B_qlora_single_device.yaml
Should be able to scrap the recipes/configs/ part here, right? (Same comment for the LoRA command below)
Done
A comparison of the smoothed loss curves between QLoRA and LoRA can be seen below (purple being the QLoRA loss curve).

.. image:: /_static/img/qlora_experiment.png
Btw you can also add explicit labels to the two lines in the figure, e.g. iconic-pasma-57 -> LoRA and azure-bird-56 -> QLoRA
Sorry to be annoying, but can you filter the x-axis to [0, 1000] or something in wandb and reupload? Otherwise it looks weird that one is running longer.
`QLoRA <https://arxiv.org/abs/2305.14314>`_ builds on top of `LoRA <https://arxiv.org/abs/2106.09685>`_ to enable additional
memory efficiency on top of LoRA. In LoRA, model parameters can be thought of as existing in two partitions: adapters, which are
nit: multiple instances of "on top of LoRA" is repetitive
In this tutorial, we'll learn about `QLoRA <https://arxiv.org/abs/2305.14314>`_, an enhancement on top of
`LoRA <https://arxiv.org/abs/2106.09685>`_ that maintains frozen model parameters in 4-bit quantized precision, thereby reducing memory usage. We'll
walk through how QLoRA can be utilized within TorchTune to finetune a Llama2-7b model in < 10 GB of memory.
nit: you should find and replace TorchTune -> torchtune as we have done it everywhere else since after you first opened this PR
`QLoRA <https://arxiv.org/abs/2305.14314>`_ builds on top of `LoRA <https://arxiv.org/abs/2106.09685>`_ to enable further
memory savings. In LoRA, model parameters can be thought of as existing in two partitions: adapters, which are
low-rank matrices added to different layers of a neural network, and base model parameters, which are parameters that are part of
the original model. In vanilla LoRA style training, both these parameters are held in the same precision (typically fp32 or bf16), and
nit
the original model. In vanilla LoRA style training, both these parameters are held in the same precision (typically fp32 or bf16), and
the original model. In vanilla LoRA-style training, both these parameters are held in the same precision (typically fp32 or bf16), and |
the `NF4Tensor <https://github.com/pytorch-labs/ao/blob/b9beaf351e27133d189b57d6fa725b1a7824a457/torchao/dtypes/nf4tensor.py#L153>`_ abstraction from the `TorchAO library <https://github.com/pytorch-labs/ao>`_ to build QLoRA components as specified in the paper.
`TorchAO library <https://github.com/pytorch-labs/ao>`_ is a PyTorch-native library that allows you to quantize and prune your models.
nit: I think we want TorchAO -> torchao
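As a quick illustration of the ``NF4Tensor`` abstraction quoted above, a minimal sketch follows; the import path mirrors the link in the quoted text, and the exact torchao API may differ across versions.

.. code-block:: python

    # Minimal sketch: quantize a bf16 weight to 4-bit NormalFloat via torchao.
    # Import path follows the link in the quoted text; exact API may vary by version.
    import torch
    from torchao.dtypes.nf4tensor import to_nf4

    weight = torch.randn(4096, 4096, dtype=torch.bfloat16)  # stand-in for a frozen base weight
    nf4_weight = to_nf4(weight)  # packed 4-bit values plus (doubly quantized) scale metadata

    print(type(nf4_weight))  # NF4Tensor

Per the quoted tutorial text, torchtune applies this only to the frozen base model parameters; the LoRA adapter parameters stay in the original precision.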
Next, there are a couple of details essential to checkpointing (i.e. ``state_dict``) of QLoRA-enabled models.
To integrate well with TorchTune's :ref:`checkpointing <checkpointing_label>`, we need to convert ``NF4Tensors`` back to their
original precision (generally fp32/bf16). This allows QLoRA-trained checkpoints to interoperate well with the rest of the ecosystem, within
TorchTune and beyond (i.e. checkpoint conversion, post-training quantization, evaluation, inference). This conversion process also allows LoRA adapter weights to be merged back into the base model as done
nit. Also might remove checkpoint conversion since idk that it's really an ecosystem thing in the same way the other items are
TorchTune and beyond (i.e. checkpoint conversion, post-training quantization, evaluation, inference). This conversion process also allows LoRA adapter weights to be merged back into the base model as done
TorchTune and beyond (e.g. checkpoint conversion, post-training quantization, evaluation, inference). This conversion process also allows LoRA adapter weights to be merged back into the base model as done |
converted tensors to the CPU. This offloading is to avoid peaking memory by maintaining an entire bf16/fp32 copy of the ``state_dict``
on GPU.
nit: this sentence could be a little unclear, seemingly implying that the way we avoid peaking memory is by maintaining an entire bf16/fp32 copy on GPU.
.. code-block:: bash

    tune run lora_finetune_single_device --config llama2/7B_lora_single_device.yaml compile=True
tune run lora_finetune_single_device --config llama2/7B_lora_single_device.yaml compile=True
tune run lora_finetune_single_device --config llama2/7B_qlora_single_device.yaml compile=True |
oh shoot, .yaml is not correct we need to remove that as well.
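Combining the two fixes in this thread (swap in the QLoRA config and drop the .yaml extension), the command would presumably end up as:

.. code-block:: bash

    tune run lora_finetune_single_device --config llama2/7B_qlora_single_device compile=True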
1|228|Loss: 0.8158286809921265: 1%| | 228/25880 [11:59<1:48:16, 3.95it/s

A comparison of the smoothed loss curves between QLoRA and LoRA can be seen below (purple being the QLoRA loss curve).
You can remove this now that you've added the legend
A comparison of the smoothed loss curves between QLoRA and LoRA can be seen below (purple being the QLoRA loss curve).
A comparison of the smoothed loss curves between QLoRA and LoRA can be seen below. |
In the next section, we'll learn about how to use QLoRA in TorchTune to build a QLoRA quantized Llama2-7b model, as well as some nuances around
checkpointing that are important to be aware of to avoid spiking memory usage.
No longer applicable
Codecov Report: All modified and coverable lines are covered by tests ✅

Project coverage on this PR: 26.69% (base coverage on main not reported). Files: 145, Lines: 6147, Branches: 0, Hits: 1641, Misses: 4506, Partials: 0. View full report in Codecov by Sentry.
OK left a handful of other comments, please make sure they're addressed before landing. Modulo that I think this is looking good!
Context
Changelog
Test plan