Add fp32 support for QLoRA #595

Merged: rohan-varma merged 16 commits into main from nf32 on Apr 2, 2024
Conversation

rohan-varma (Member) commented Mar 26, 2024

Context

  • QLoRA is currently coupled to bf16, but some older hardware doesn't support bf16, and we'd ideally like to offer at least one memory-efficient finetuning solution for those architectures. For example, T4s have 16 GB of memory and don't support bf16. This PR enables QLoRA to run compute and checkpointing in fp32 instead of bf16, removing the hard coupling between QLoRA training and bf16.

Changelog

  • Remove bf16 assumptions from nf4
  • Remove bf16 assumptions from LoRALinear and upstream
  • Generalize a few functions to be less coupled to bf16 (see the sketch after this list)
  • Fix tests
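For illustration only, here is a minimal sketch of the kind of change the changelog describes: instead of assuming bf16, a LoRA-style linear takes a dtype argument and computes in whatever dtype its inputs use. The SimpleLoRALinear class and its signature are hypothetical, not torchtune's actual LoRALinear.

```python
import torch
import torch.nn.functional as F
from torch import nn


class SimpleLoRALinear(nn.Module):
    """Hypothetical dtype-agnostic LoRA-style linear (sketch, not torchtune's LoRALinear)."""

    def __init__(self, in_dim: int, out_dim: int, rank: int, dtype: torch.dtype = torch.float32):
        super().__init__()
        # Frozen base weight plus trainable low-rank adapters, all created in the
        # caller-supplied dtype rather than a hard-coded torch.bfloat16.
        self.weight = nn.Parameter(torch.empty(out_dim, in_dim, dtype=dtype), requires_grad=False)
        self.lora_a = nn.Linear(in_dim, rank, bias=False, dtype=dtype)
        self.lora_b = nn.Linear(rank, out_dim, bias=False, dtype=dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Compute follows the input dtype; fp32 activations stay fp32 end to end.
        return F.linear(x, self.weight) + self.lora_b(self.lora_a(x))
```

With dtype=torch.float32, the adapters and their gradients stay in fp32, which is what lets QLoRA run on hardware without bf16 support.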

NOTE

  • We've temporarily forked LinearNF4 from torchao while the corresponding changes land in ao. We'll revert to ao's implementation as soon as possible, but we need the changes in this forked version to enable fp32 support (i.e. to decouple NF4 compute from bf16). This is a temporary (~days) mitigation.
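To illustrate the NF4 decoupling at a high level, a hedged sketch is below: the dequantized weight follows the activation dtype instead of being cast to bf16. The dequantize argument is a stand-in for whatever routine the forked LinearNF4 actually uses; this is not torchao's or the fork's real code.

```python
import torch
import torch.nn.functional as F


def linear_nf4_any_dtype(x: torch.Tensor, nf4_weight, dequantize) -> torch.Tensor:
    """Sketch of an NF4 linear whose compute dtype follows the activation.

    `nf4_weight` and `dequantize` are hypothetical stand-ins for the quantized
    weight object and dequantization routine used by the forked LinearNF4.
    """
    # Before: the dequantized weight was cast to torch.bfloat16 unconditionally.
    # After: cast to x.dtype, so fp32 inputs get fp32 compute and fp32 gradients.
    w = dequantize(nf4_weight).to(x.dtype)
    return F.linear(x, w)
```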

Test plan

  • Modified unit tests: fp32 computation, checkpointing, and parity are covered; fp32 coverage for QLoRA matches the existing bf16 coverage.
  • Verified manually that gradients are computed in fp32 (by inspecting the .grad field; see the sketch after the PR description).
  • Checkpoints are saved in fp32.
  • Ran the recipe: tune lora_finetune_single_device --config llama2/7B_qlora_single_device dtype=fp32 epochs=1
  • Loss is comparable to QLoRA and LoRA-bf16 (see QLoRA #478 for those curves).
  • Memory: increase over QLoRA bf16 is +20% peak memory allocated, +3% reserved memory:
    Memory Stats:
     GPU peak memory allocation: 6.98 GB
     GPU peak memory reserved: 9.57 GB
     GPU peak memory active: 6.98 GB
  • Eval result:

ghstack-source-id: aa906a002fccbc9e80acfe3c4848febe23d5071f
Pull Request resolved: #590
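The exact commands used for the manual gradient and checkpoint checks aren't shown in the PR; a hedged sketch of what such a check could look like (the helper name and its arguments are hypothetical) is:

```python
import torch


def assert_fp32_training_state(model: torch.nn.Module, checkpoint_path: str) -> None:
    """Hypothetical helper: verify fp32 gradients and fp32 checkpoint tensors."""
    # After a backward pass, every trainable parameter's gradient should be fp32.
    for name, param in model.named_parameters():
        if param.requires_grad and param.grad is not None:
            assert param.grad.dtype == torch.float32, f"{name}.grad is {param.grad.dtype}"

    # The saved checkpoint should contain fp32 floating-point tensors as well.
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    for name, tensor in state_dict.items():
        if isinstance(tensor, torch.Tensor) and tensor.is_floating_point():
            assert tensor.dtype == torch.float32, f"{name} is {tensor.dtype}"
```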

pytorch-bot bot commented Mar 26, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/595

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 8fdb0af with merge base 2940941:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Mar 26, 2024
@rohan-varma rohan-varma marked this pull request as draft March 26, 2024 18:22
@rohan-varma rohan-varma marked this pull request as ready for review March 26, 2024 20:59
@rohan-varma rohan-varma changed the title Nf32 Add fp32 support for QLoRA Mar 26, 2024
Review thread on the state-dict post hook (diff excerpt, old vs. new):

partial(reparametrize_as_bf16_state_dict_post_hook, offload_to_cpu=True)
partial(
    reparametrize_as_dtype_state_dict_post_hook,
    # TODO this is clowny, figure out a better way to get what precision the rest

Contributor commented:
Honestly I don't really see a better way to do this
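For context on the snippet above, a hedged sketch of how the generalized hook might be wired up is below. The import path and the dtype keyword are assumptions; only reparametrize_as_dtype_state_dict_post_hook and offload_to_cpu appear in the diff itself.

```python
from functools import partial

import torch

# Assumption: import location; the actual module path in torchtune may differ.
from torchtune.modules.peft import reparametrize_as_dtype_state_dict_post_hook

# Assumption: the generalized hook accepts a `dtype` kwarg alongside
# `offload_to_cpu`; the exact signature may differ from this sketch.
state_dict_hook = partial(
    reparametrize_as_dtype_state_dict_post_hook,
    offload_to_cpu=True,
    dtype=torch.float32,  # follow the training dtype instead of hard-coding bf16
)

# Hypothetical usage: register the hook on the model so exported state dicts
# come out in the configured precision.
# model._register_state_dict_hook(state_dict_hook)
```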

Review thread on an import change (diff excerpt):

@@ -9,6 +9,8 @@
import torch.nn.functional as F

from torch import nn, Tensor

# from torchtune.modules.low_precision.nf4_linear import _linear_nf4

Contributor commented:
remove

ebsmothers (Contributor) left a review:

Looks great!

@rohan-varma rohan-varma merged commit 2ac4258 into main Apr 2, 2024
20 checks passed
tcapelle pushed a commit to tcapelle/torchtune that referenced this pull request Apr 5, 2024
@joecummings joecummings deleted the nf32 branch April 11, 2024 15:40