
Only merge model weights in LoRA recipe when save_adapter_weights_only=False #1476

Merged: 11 commits into pytorch:main on Sep 15, 2024

Conversation

@SalmanMohammadi (Collaborator) commented Sep 3, 2024

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Please link to any issues this PR addresses.

Changelog

When saving LoRA checkpoints, we only need to copy and merge the full model state dict when save_adapter_weights_only=False. For distributed recipes, this change only skips the merging logic; the full state dict is still gathered.
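
To make the gating concrete, here is a minimal sketch of the flow, assuming illustrative helper names (`is_adapter_key`, `merge_lora_weights`, `build_checkpoint_dict`) and the usual `.lora_a`/`.lora_b` key layout. This is a simplified stand-in for the recipe code, not the implementation itself:

```python
# Illustrative sketch only; helper names and key layout are assumptions,
# not the actual torchtune recipe code.
from typing import Any, Dict

import torch


def is_adapter_key(name: str) -> bool:
    # Stand-in for the recipe's adapter_key_filter: keep only LoRA A/B params.
    return ".lora_a." in name or ".lora_b." in name


def merge_lora_weights(
    state_dict: Dict[str, torch.Tensor], rank: int, alpha: float
) -> Dict[str, torch.Tensor]:
    # Fold each LoRA update back into its base weight: W <- W + (alpha / rank) * (B @ A).
    merged = {k: v.clone() for k, v in state_dict.items() if not is_adapter_key(k)}
    scaling = alpha / rank
    for key in merged:
        if not key.endswith(".weight"):
            continue
        base = key[: -len(".weight")]
        a, b = f"{base}.lora_a.weight", f"{base}.lora_b.weight"
        if a in state_dict and b in state_dict:
            merged[key] = merged[key] + scaling * (state_dict[b] @ state_dict[a])
    return merged


def build_checkpoint_dict(
    model: torch.nn.Module,
    save_adapter_weights_only: bool,
    lora_rank: int,
    lora_alpha: float,
) -> Dict[str, Any]:
    state_dict = model.state_dict()
    checkpoint_dict: Dict[str, Any] = {
        # Adapter weights are always saved, whatever the flag is set to.
        "adapter": {k: v for k, v in state_dict.items() if is_adapter_key(k)}
    }
    # Only copy and merge the full model state dict when it will actually be written out.
    if not save_adapter_weights_only:
        checkpoint_dict["model"] = merge_lora_weights(
            state_dict, rank=lora_rank, alpha=lora_alpha
        )
    return checkpoint_dict
```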

I've also added tests for resuming recipe state with LoRA DPO.

Test plan

Please make sure to do each of the following if applicable to your PR. (If you're not sure about any one of these, just ask and we will happily help. We also have a contributing page for some guidance on contributing.)

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

UX

If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Example of docstring:


Example in our docs: https://pytorch.org/torchtune/main/tutorials/qat_finetune.html#applying-qat-to-llama3-models

  • I did not change any public API;
  • I have added an example to docs or docstrings;


pytorch-bot bot commented Sep 3, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1476

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 4991014 with merge base be8f1e7:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label on Sep 3, 2024
@codecov-commenter commented Sep 15, 2024

Codecov Report

Attention: Patch coverage is 23.33333% with 69 lines in your changes missing coverage. Please review.

Project coverage is 27.00%. Comparing base (60cf96f) to head (05620fe).
Report is 4 commits behind head on main.

Files with missing lines                     | Patch % | Lines
tests/recipes/test_lora_dpo_single_device.py | 27.27%  | 48 Missing ⚠️
recipes/lora_dpo_single_device.py            | 0.00%   | 6 Missing ⚠️
recipes/lora_finetune_single_device.py       | 0.00%   | 6 Missing ⚠️
recipes/lora_dpo_distributed.py              | 0.00%   | 3 Missing ⚠️
recipes/lora_finetune_distributed.py         | 0.00%   | 3 Missing ⚠️
tests/recipes/utils.py                       | 25.00%  | 3 Missing ⚠️

❗ There is a different number of reports uploaded between BASE (60cf96f) and HEAD (05620fe): HEAD has 3 fewer uploads than BASE.

Flag | BASE (60cf96f) | HEAD (05620fe)
     | 6              | 3
Additional details and impacted files
@@             Coverage Diff             @@
##             main    #1476       +/-   ##
===========================================
- Coverage   73.12%   27.00%   -46.13%     
===========================================
  Files         289      290        +1     
  Lines       14175    14252       +77     
===========================================
- Hits        10366     3849     -6517     
- Misses       3809    10403     +6594     

☔ View full report in Codecov by Sentry.

@@ -0,0 +1,182 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
Contributor
You're a hero

Collaborator Author
I was just afraid someone would ask me to run the recipe manually, so this felt like less effort.

Comment on lines 555 to 559
adapter_state_dict = {
k: v
for k, v in self._model.state_dict().items()
if adapter_key_filter(k)
}
Contributor
Doesn't this still materialize the entire state dict though?

alpha=self._lora_alpha,
)
checkpoint_dict.update({training.MODEL_KEY: merged_state_dict})
if not self._save_adapter_weights_only:
Contributor
I think @pbontrager had looked at something with this previously, but we still create the full state dict on CPU in this case then, right? To actually reduce the memory footprint to just the adapter weights we'd need to change the call to get_full_model_state_dict above (btw it's ok to say "let's do this in a follow-up")

@ebsmothers (Contributor) left a comment
A couple questions around whether we can get around materializing the full model state dict when save_adapter_weights_only=True. (I think it'll be trickier for the distributed recipe, but would be nice to sort out for single-device.) Other than that looks good!
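
For context on that follow-up, a rough sketch of what gathering only the adapter parameters might look like in the distributed recipes, assuming FSDP2-style sharding where parameters are DTensors (the function name `gather_adapter_state_dict` is hypothetical, and none of this is part of this PR):

```python
# Hypothetical follow-up sketch (not part of this PR): gather only the adapter
# parameters on rank zero rather than materializing the full model state dict.
from typing import Callable, Dict

import torch
from torch.distributed._tensor import DTensor


def gather_adapter_state_dict(
    model: torch.nn.Module,
    adapter_key_filter: Callable[[str], bool],
    is_rank_zero: bool,
) -> Dict[str, torch.Tensor]:
    adapter_state_dict: Dict[str, torch.Tensor] = {}
    for name, param in model.state_dict().items():
        if not adapter_key_filter(name):
            # Skip non-adapter weights entirely so they are never gathered.
            continue
        # All ranks must participate in the all-gather for sharded params;
        # only rank zero keeps the materialized copy on CPU.
        full_param = param.full_tensor() if isinstance(param, DTensor) else param
        if is_rank_zero:
            adapter_state_dict[name] = full_param.cpu()
    return adapter_state_dict
```

As noted above, avoiding the full state dict materialization was left as a possible follow-up rather than folded into this change.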

@SalmanMohammadi merged commit dcb9531 into pytorch:main Sep 15, 2024
17 checks passed
@SalmanMohammadi deleted the save_adapter_only_fix branch September 15, 2024 20:22
Labels
CLA Signed: This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
4 participants