Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

QLoRA tutorial #693

Merged
merged 32 commits into from
Apr 15, 2024
Merged

QLoRA tutorial #693

merged 32 commits into from
Apr 15, 2024

Conversation

rohan-varma
Copy link
Member

@rohan-varma rohan-varma commented Apr 11, 2024

Context

  • Adds a QLoRA tutorial to accompany the LoRA tutorial in torchtune. Discussed with @ebsmothers, the nuances and details in QLoRA are enough for it to warrant its own tutorial. Will also link LoRA tutorial to QLoRA and backlink QLoRA to LoRA tutorial.

Changelog

  • ...

Test plan

image

Copy link

pytorch-bot bot commented Apr 11, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/693

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit fa6b0a3 with merge base 5b0dc57 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Apr 11, 2024
@rohan-varma rohan-varma marked this pull request as draft April 11, 2024 22:27
@rohan-varma rohan-varma marked this pull request as ready for review April 12, 2024 23:16
@rohan-varma rohan-varma changed the title [WIP] QLoRA tutorial QLoRA tutorial Apr 12, 2024
docs/source/examples/qlora_finetune.rst Outdated Show resolved Hide resolved
docs/source/examples/qlora_finetune.rst Outdated Show resolved Hide resolved
docs/source/examples/qlora_finetune.rst Outdated Show resolved Hide resolved
accuracy.

The `QLoRA paper <https://arxiv.org/abs/2305.14314>`_ introduces two key abstractions to decrease memory usage and avoid accuracy degradation: the bespoke 4-bit NormatFloat
type, and a double quantization method that quantizes the quantization parameters themselves to save even more memory. TorchTune uses
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🤯

docs/source/examples/qlora_finetune.rst Outdated Show resolved Hide resolved
parameters are still held in the original precision, and activations, gradients, and optimizer states still exist in the higher precision to preserve
accuracy.

The `QLoRA paper <https://arxiv.org/abs/2305.14314>`_ introduces two key abstractions to decrease memory usage and avoid accuracy degradation: the bespoke 4-bit NormatFloat
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I feel like you're linking to the paper too many times, usually I just do it once upon introduction

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nah, I think this is fine. you don't know which paragraph someone will start reading at if they're skimming or jumping around

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you should link every other reference to the paper to keep people on their toes.

Comment on lines 39 to 41
quantization is done through the method highlighted in the original `QLoRA paper <https://arxiv.org/abs/2305.14314>`_. Adapter
parameters are still held in the original precision, and activations, gradients, and optimizer states still exist in the higher precision to preserve
accuracy.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One thing that'd be nice: basically take the diagram in the LoRA tutorial demonstrating full finetune -> LoRA and add one more for LoRA -> QLoRA. (The diagrams take a bit of time so I feel this is more of a nice-to-have at this point.) But if you're interested let me know and I can dig up the original

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Definitely interested! Will punt this out to a follow up though.

docs/source/examples/qlora_finetune.rst Outdated Show resolved Hide resolved
accuracy.

The `QLoRA paper <https://arxiv.org/abs/2305.14314>`_ introduces two key abstractions to decrease memory usage and avoid accuracy degradation: the bespoke 4-bit NormatFloat
type, and a double quantization method that quantizes the quantization parameters themselves to save even more memory. TorchTune uses
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it worth giving some more details on either of these two optimizations, or do you think it's too in the weeds?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO its too in the weeds and not worth it to just re explain if its in the paper. Can add a line directing folks to the paper, but IMO its already clear enough to read the paper for this sort of detail

docs/source/examples/qlora_finetune.rst Outdated Show resolved Hide resolved
Comment on lines 179 to 185
in a typical LoRA training flow.

To achieve this, when using TorchTune's ``qlora_llama2_7b`` builder, we automatically register a hook, :code:`reparametrize_as_dtype_state_dict_post_hook`,
that runs after calling ``.state_dict()`` on the top level model. This hook converts ``NF4Tensors`` back to their original precision, while also offloading these
converted tensors to the CPU. This offloading is to avoid peaking memory by maintaining an entire bf16/fp32 copy of the ``state_dict``
on GPU, which could lead to potential OOMs during checkpoint save, even if memory is appropriately managed during
training.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Personally I think this whole section could be more easily illuminated via a code block with e.g. .state_dict() with peak memory printed without the state dict hook, then add the hook and do the same thing.

Comment on lines 226 to 252
As well as during training:

.. code-block:: python

Memory Stats::
GPU peak memory allocation: 14.40 GB
GPU peak memory reserved: 15.57 GB
GPU peak memory active: 14.40 GB


Comparing to the memory usage during model initialization for QLoRA, we see about a 35% decrease in peak memory reserved:

.. code-block:: python

Memory Stats after model init::
GPU peak memory allocation: 7.36 GB
GPU peak memory reserved: 9.13 GB
GPU peak memory active: 7.36 GB

As well as a 40% decrease in peak memory reserved during training:

.. code-block:: python

Memory Stats::
GPU peak memory allocation: 5.54 GB
GPU peak memory reserved: 9.29 GB
GPU peak memory active: 5.54 GB
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a bit torn here.. while I think it's nice to print the raw output that someone will see in the logs, personally I would just do it one time, e.g. "you will see something like this" and then put it in a table (QLoRA, LoRA) x (peak init memory, peak training memory). Then it's a bit more digestible

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added table

-----------------------------------------

Putting it all together, we can now finetune a model using TorchTune's `LoRA recipe <https://github.com/pytorch/torchtune/blob/48626d19d2108f92c749411fbd5f0ff140023a25/recipes/lora_finetune.py>`_,
with a `<QLoRA configuration https://github.com/pytorch/torchtune/blob/main/recipes/configs/llama2/7B_qlora_single_device.yaml>`_.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This hyperlink format is so annoying lol

Suggested change
with a `<QLoRA configuration https://github.com/pytorch/torchtune/blob/main/recipes/configs/llama2/7B_qlora_single_device.yaml>`_.
with a `QLoRA configuration <https://github.com/pytorch/torchtune/blob/main/recipes/configs/llama2/7B_qlora_single_device.yaml>`_.


.. code-block:: bash

tune run lora_finetune_single_device --config recipes/configs/llama2/7B_qlora_single_device.yaml
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should be able to scrap the recipes/configs/ part here, right? (Same comment for the LoRA command below)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

docs/source/examples/qlora_finetune.rst Outdated Show resolved Hide resolved

A comparison of the smoothed loss curves between QLoRA and LoRA can be seen below (purple being the QLoRA loss curve).

.. image:: /_static/img/qlora_experiment.png
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Btw you can also add explicit labels to the two lines in the figure, e.g. iconic-pasma-57 -> LoRA and azure-bird-56 -> QLoRA

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry to be annoying, but can you filter the x-axis to [0, 1000] or something in wandb and reupload? Otherwise it looks weird that one is running longer.

Comment on lines 31 to 32
`QLoRA <https://arxiv.org/abs/2305.14314>`_ builds on top of `LoRA <https://arxiv.org/abs/2106.09685>`_ to enable additional
memory efficiency on top of LoRA. In LoRA, model parameters can be thought of as existing in two partitions: adapters, which are
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: multiple instances of "on top of LoRA" is repetitive


In this tutorial, we'll learn about `QLoRA <https://arxiv.org/abs/2305.14314>`_, an enhancement on top of
`LoRA <https://arxiv.org/abs/2106.09685>`_ that maintains frozen model parameters in 4-bit quantized precision, thereby reducing memory usage. We'll
walk through how QLoRA can be utilized within TorchTune to finetune a Llama2-7b model in < 10 GB of memory.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: you should find and replace TorchTune -> torchtune as we have done it everywhere else since after you first opened this PR

`QLoRA <https://arxiv.org/abs/2305.14314>`_ builds on top of `LoRA <https://arxiv.org/abs/2106.09685>`_ to enable further
memory savings. In LoRA, model parameters can be thought of as existing in two partitions: adapters, which are
low-rank matrices added to different layers of a neural network, and base model parameters, which are parameters that are part of
the original model. In vanilla LoRA style training, both these parameters are held in the same precision (typically fp32 or bf16), and
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit

Suggested change
the original model. In vanilla LoRA style training, both these parameters are held in the same precision (typically fp32 or bf16), and
the original model. In vanilla LoRA-style training, both these parameters are held in the same precision (typically fp32 or bf16), and

Comment on lines 45 to 46
the `NF4Tensor <https://github.com/pytorch-labs/ao/blob/b9beaf351e27133d189b57d6fa725b1a7824a457/torchao/dtypes/nf4tensor.py#L153>`_ abstraction from the `TorchAO library <https://github.com/pytorch-labs/ao>`_ to build QLoRA components as specified in the paper.
`TorchAO library <https://github.com/pytorch-labs/ao>`_ is a PyTorch-native library that allows you to quantize and prune your models.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I think we want TorchAO -> torchao

Next, there are a couple of details essential to checkpointing (i.e. ``state_dict``) of QLoRA-enabled models.
To integrate well with TorchTune's :ref:`checkpointing <checkpointing_label>`, we need to convert ``NF4Tensors`` back to their
original precision (generally fp32/bf16). This allows QLoRA-trained checkpoints to interoperate well with the rest of the ecosystem, within
TorchTune and beyond (i.e. checkpoint conversion, post-training quantization, evaluation, inference). This conversion process also allows LoRA adapter weights to be merged back into the base model as done
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit. Also might remove checkpoint conversion since idk that it's really an ecosystem thing in the same way the other items are

Suggested change
TorchTune and beyond (i.e. checkpoint conversion, post-training quantization, evaluation, inference). This conversion process also allows LoRA adapter weights to be merged back into the base model as done
TorchTune and beyond (e.g. checkpoint conversion, post-training quantization, evaluation, inference). This conversion process also allows LoRA adapter weights to be merged back into the base model as done

Comment on lines 118 to 119
converted tensors to the CPU. This offloading is to avoid peaking memory by maintaining an entire bf16/fp32 copy of the ``state_dict``
on GPU.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: this sentence could be a little unclear, seemingly implying that the way we avoid peaking memory is by maintaining an entire bf16/fp32 copy on GPU.


.. code-block:: bash

tune run lora_finetune_single_device --config llama2/7B_lora_single_device.yaml compile=True
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
tune run lora_finetune_single_device --config llama2/7B_lora_single_device.yaml compile=True
tune run lora_finetune_single_device --config llama2/7B_qlora_single_device.yaml compile=True

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh shoot, .yaml is not correct we need to remove that as well.


1|228|Loss: 0.8158286809921265: 1%| | 228/25880 [11:59<1:48:16, 3.95it/s

A comparison of the smoothed loss curves between QLoRA and LoRA can be seen below (purple being the QLoRA loss curve).
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can remove this now that you've added the legend

Suggested change
A comparison of the smoothed loss curves between QLoRA and LoRA can be seen below (purple being the QLoRA loss curve).
A comparison of the smoothed loss curves between QLoRA and LoRA can be seen below.

Comment on lines 274 to 275
In the next section, we'll learn about how to use QLoRA in TorchTune to build a QLoRA quantized Llama2-7b model, as well as some nuances around
checkpointing that are important to be aware of to avoid spiking memory usage.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No longer applicable

@codecov-commenter
Copy link

Codecov Report

All modified and coverable lines are covered by tests ✅

❗ No coverage uploaded for pull request base (main@5b0dc57). Click here to learn what that means.

Additional details and impacted files
@@           Coverage Diff           @@
##             main     #693   +/-   ##
=======================================
  Coverage        ?   26.69%           
=======================================
  Files           ?      145           
  Lines           ?     6147           
  Branches        ?        0           
=======================================
  Hits            ?     1641           
  Misses          ?     4506           
  Partials        ?        0           

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Copy link
Contributor

@ebsmothers ebsmothers left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK left a handful of other comments, please make sure they're addressed before landing. Modulo that I think this is looking good!

@rohan-varma rohan-varma merged commit 0914d5c into main Apr 15, 2024
27 checks passed
@joecummings joecummings deleted the qlora_tutorial branch April 16, 2024 02:11
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants