QLoRA with Llama 3.1 405B #1232

Merged: 33 commits into pytorch:main on Sep 17, 2024

Conversation

@pbontrager (Contributor) commented Jul 26, 2024

Edit: this PR has been resuscitated. The main change since the previous version is adding a new way to pass checkpoint files. Now we can just do

  checkpoint_files:
    formatted_string: model-{}-of-{}.safetensors
    max_filename: 00191

without having to copy-paste 191 rows of config. We should integrate this into all our 70B configs too when we get a chance.
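For illustration, here is a minimal sketch of how that spec could expand into the full file list; the helper name and exact behavior are hypothetical, not necessarily how the FormattedCheckpointFiles class implements it:

  # Hypothetical expansion of the formatted-string checkpoint spec shown above.
  # Assumes files are 1-indexed and zero-padded to the width of max_filename.
  def expand_checkpoint_files(formatted_string: str, max_filename: str) -> list[str]:
      n_files = int(max_filename)
      width = len(max_filename)
      return [
          formatted_string.format(f"{i:0{width}d}", max_filename)
          for i in range(1, n_files + 1)
      ]

  files = expand_checkpoint_files("model-{}-of-{}.safetensors", "00191")
  # ['model-00001-of-00191.safetensors', ..., 'model-00191-of-00191.safetensors']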

Test plan

Added some new tests for the FormattedCheckpointFiles class.

pytest tests/torchtune/training/checkpointing/test_checkpointer_utils.py
...
====== 11 passed, 1 warning in 0.14s ===================

E2E test:

tune run --nproc_per_node 8 lora_finetune_distributed --config llama3_1/405B_qlora \
max_steps_per_epoch=2 gradient_accumulation_steps=2 epochs=2
...
1|2|Loss: 1.7182660102844238: 100%|████████████████████████████| 2/2 [02:39<00:00, 79.59s/it]saving checkpoint
INFO:torchtune.utils._logging:Adapter checkpoint of size 1.25 GB saved to /tmp/Meta-Llama-3.1-405B-Instruct/adapter_0.pt
INFO:torchtune.utils._logging:Adapter checkpoint of size 1.25 GB saved to /tmp/Meta-Llama-3.1-405B-Instruct/adapter_model.bin
INFO:torchtune.utils._logging:Adapter checkpoint of size 0.00 GB saved to /tmp/Meta-Llama-3.1-405B-Instruct/adapter_config.json
INFO:torchtune.utils._logging:Recipe checkpoint of size 2.50 GB saved to /tmp/Meta-Llama-3.1-405B-Instruct/recipe_state.pt
1|2|Loss: 1.7182660102844238: 100%|██████████████████████████████| 2/2 [02:53<00:00, 86.50s/it]
2|4|Loss: 2.5931971073150635: 100%|████████████████████████| 2/2 [02:38<00:00, 79.58s/it]saving checkpoint
INFO:torchtune.utils._logging:Adapter checkpoint of size 1.25 GB saved to /tmp/Meta-Llama-3.1-405B-Instruct/adapter_1.pt
INFO:torchtune.utils._logging:Adapter checkpoint of size 1.25 GB saved to /tmp/Meta-Llama-3.1-405B-Instruct/adapter_model.bin
INFO:torchtune.utils._logging:Adapter checkpoint of size 0.00 GB saved to /tmp/Meta-Llama-3.1-405B-Instruct/adapter_config.json
INFO:torchtune.utils._logging:Saving final epoch checkpoint.
INFO:torchtune.utils._logging:Please note that you have set adapter_only=True, so only adapter weights will be saved.You need to merge the adapter weights into your base model for further use. See FullModelHFCheckpointer.save_checkpoint for more details.
2|4|Loss: 2.5931971073150635: 100%|███████████████████████████████████| 2/2 [02:45<00:00, 82.55s/it]

Setting aside the fact that our logs are annoying here, this works.

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

This PR adds a Llama 3.1 405B QLoRA config for the lora_finetune_fsdp2 recipe. The config requires 8 x A100s to run and, as an example, successfully finetunes Llama 3.1 405B on the alpaca dataset. There are several caveats for getting this large model to fit on a single node:

  • Model checkpointing causes NCCL timeout errors, so we only save the adapter weights, which can be merged into the base model afterward (see the sketch after this list)
  • Training is slow, at around 10 minutes for 16 gradient steps (this should improve with compile support)
  • Only a context length of under 4k fits
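For reference, the post-training merge step is just the standard LoRA update; a minimal sketch under that assumption (generic math, not torchtune's actual merge utility):

  # Generic LoRA merge: W' = W + (alpha / rank) * (B @ A).
  # Illustrative only; torchtune's own merge path may differ.
  import torch

  def merge_lora_weight(
      base_weight: torch.Tensor,  # (out_dim, in_dim)
      lora_a: torch.Tensor,       # (rank, in_dim)
      lora_b: torch.Tensor,       # (out_dim, rank)
      alpha: float,
      rank: int,
  ) -> torch.Tensor:
      return base_weight + (alpha / rank) * (lora_b @ lora_a)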

This PR was made jointly by me and @joecummings.

Changelog

  • New config
  • New model builders
  • Updated docs for the new model builders
  • Updated the fsdp2 recipe's save_checkpoint to support saving only the adapter weights

Test plan

Please make sure to do each of the following if applicable to your PR. (If you're not sure about any one of these just ask and we will happily help.)

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
    • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

Test Memory + Checkpointing
tune run --nproc_per_node 8 lora_finetune_fsdp2 --config llama3_1/405B_qlora log_peak_memory_stats=True dataset.packed=True dataset.max_seq_len=2048 max_steps_per_epoch=2 epochs=2

Test Model Loss
tune run --nproc_per_node 8 lora_finetune_fsdp2 --config llama3_1/405B_qlora

pytorch-bot commented Jul 26, 2024

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1232

✅ No Failures as of commit 8b1a33d with merge base 7045e96

@facebook-github-bot added the CLA Signed label on Jul 26, 2024
@pbontrager changed the title from "Meta Llama 3.1 405B" to "QLoRA with Llama 3.1 405B" on Jul 26, 2024
@@ -63,6 +63,27 @@ def llama3_1_70b() -> TransformerDecoder:
)


def llama3_1_405b() -> TransformerDecoder:
"""
Builder for creating a Llama3 model initialized w/ the default 405B parameter values.
Collaborator:

nit: Llama3.1
here and in the return

Args:
model (FSDPModule): wrapped module
is_rank_zero (bool): flag to check if the process is on rank 0
trainable_only (bool): flag to only return state dict of trainable parameters
Collaborator:

How will this work with #1220?

Contributor Author:

Initially I thought just saving the adapter would be enough, but for the 405B model, consolidating the state_dict is also prohibitively expensive. This update can be added to the other recipes after this lands.
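For context, the core idea is just filtering the state dict down to the trainable (adapter) parameters; a rough sketch assuming a plain nn.Module, ignoring the FSDP sharding and gathering the recipe actually has to handle:

  # Illustrative trainable-only filter; does not handle sharded FSDP parameters.
  import torch.nn as nn

  def trainable_state_dict(model: nn.Module) -> dict:
      trainable = {name for name, p in model.named_parameters() if p.requires_grad}
      return {k: v for k, v in model.state_dict().items() if k in trainable}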

Contributor:

No NCCL timeouts?

@SalmanMohammadi (Collaborator) commented Jul 26, 2024

Couple of small comments (looks like CI is failing too?). This is huge.

Two points as an aside:

  • I see this is a dev recipe, right? What would bring it into the stable recipes?

  • Following from above, would something like this warrant a writeup to help get the good word out about the amazing work you guys have been doing here, and how torchtune enables feats like this?

@joecummings (Contributor):

  • I see this is a dev recipe right? What would bring it into the stable recipes?

The only reason this is dev is that FSDP2 support just dropped in PyTorch stable a few days ago, so we want to play around with it more before moving to stable. We need FSDP2 support b/c it plays nicely with QLoRA.

  • Following from above, would something like this warrant a writeup to help get the good word out about the amazing work you guys have been doing here, and how torchtune enables feats like this?

We'd love this - I can definitely promote on Discord and Twitter and start on a formal blog writeup, but would appreciate any help here :)

@SalmanMohammadi (Collaborator):

We'd love this - I can definitely promote on Discord and Twitter and start on a formal blog writeup, but would appreciate any help here :)

I'd love to! Happy to help out here in any capacity - reviewing or helping flesh out a rough sketch.

@@ -7,10 +7,13 @@
from ._component_builders import llama3_1, lora_llama3_1
Collaborator:

linter unhappy : ( need to add to __all__?

@pbontrager (Contributor Author) commented Jul 26, 2024:

It's in __all__, no idea what's triggering this. The pre-commit flake8 passes too

@@ -19,9 +22,12 @@
"llama3_1",
"llama3_1_8b",
"llama3_1_70b",
"llama3_1_405b"
Collaborator:

missing a comma here, friend

Contributor Author:

:face_palm:

"lora_llama3_1",
"lora_llama3_1_8b",
"lora_llama3_1_70b",
"lora_llama3_1_405b"
Collaborator:

also here

checkpointer:
_component_: torchtune.utils.FullModelHFCheckpointer
checkpoint_dir: /tmp/Meta-Llama-3.1-405B-Instruct/
checkpoint_files: [
@ebsmothers (Contributor) commented Jul 29, 2024:

Guys, I think the time has come... we need to refactor the checkpoint_files field to optionally support the combination of (a) a number of files and (b) an strf filename format (or something like that). This is like 5 minutes of work and will save us from abominations like this. cc @joecummings

@SalmanMohammadi (Collaborator):

pydoclint................................................................Failed
- hook id: pydoclint
- exit code: 1

Skipping files that match this pattern: tests/torchtune/models/(\w+)/scripts/
recipes/dev/lora_finetune_fsdp2.py
torchtune/_recipe_registry.py
torchtune/models/llama3_1/__init__.py
torchtune/utils/_checkpointing/_checkpointer.py
torchtune/utils/_distributed.py
torchtune/utils/_distributed.py
    361: DOC201: Function `get_full_model_state_dict` does not have a return section in docstring
    361: DOC501: Function `get_full_model_state_dict` has "raise" statements, but the docstring does not have a "Raises" section
Loading config from user-specified .toml file: pyproject.toml
Found options defined in pyproject.toml:
{'style': 'google', 'check_return_types': 'False', 'exclude': 'tests/torchtune/models/(\\w+)/scripts/'}

welcome back to hell, brother

@joecummings (Contributor) commented Jul 30, 2024

Currently hitting this error with compile, no speedups:

[rank1]:W0730 05:54:09.440000 86176 torch/_dynamo/convert_frame.py:795] [8/8] torch._dynamo hit config.cache_size_limit (8)
[rank1]:W0730 05:54:09.440000 86176 torch/_dynamo/convert_frame.py:795] [8/8]    function: 'forward' (/home/jrcummings/.conda/envs/joe-torchtune/lib/python3.11/site-packages/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py:146)
[rank1]:W0730 05:54:09.440000 86176 torch/_dynamo/convert_frame.py:795] [8/8]    last reason: 8/0: ___check_type_id(L['self'], 206876688)
[rank1]:W0730 05:54:09.440000 86176 torch/_dynamo/convert_frame.py:795] [8/8] To log all recompilation reasons, use TORCH_LOGS="recompiles".
[rank1]:W0730 05:54:09.440000 86176 torch/_dynamo/convert_frame.py:795] [8/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html.

This happens on all ranks.
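(As a hedged aside for anyone reproducing this: the limit in that warning is the standard torch._dynamo cache-size knob, which can be raised while investigating; this is not the fix applied in this PR.)

  # Raise the dynamo recompile cache limit while debugging; the default is 8.
  import torch._dynamo as dynamo
  dynamo.config.cache_size_limit = 16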

cc @bdhirsh @weifengpy

@bdhirsh commented Aug 6, 2024

@joecummings it's a bit hard to tell exactly what we're recompiling on from the logs (since it just includes a type id). But one known cause of recompilations using DTensor today is that we overspecialize on all of the DTensor metadata (Placement/DeviceMesh).

pytorch/pytorch#124401 fixes it, so once it lands can you try again? (And we can stare at the recompiles some more if that doesn't fix it.)

@joecummings (Contributor) left a comment:

Why not allow a regex pattern for files?

Contributor:

test???????!?!??!?!?!?!

Contributor:

Also what the hell does strf_name stand for?

Contributor:

string formatted name. Open to better names though

Contributor:

test???????!?!??!?!?!?!

ok ok ok. The FormattedCheckpointFiles class was hacked together pretty quickly, first just wanted to make sure it made sense to people

@codecov-commenter:
Codecov Report

Attention: Patch coverage is 85.71429% with 7 lines in your changes missing coverage. Please review.

Project coverage is 70.74%. Comparing base (7148102) to head (ce5433b).
Report is 1 commit behind head on main.

Files with missing lines Patch % Lines
torchtune/training/_distributed.py 0.00% 2 Missing ⚠️
torchtune/training/checkpointing/_checkpointer.py 50.00% 2 Missing ⚠️
recipes/lora_finetune_distributed.py 0.00% 1 Missing ⚠️
torchtune/config/_utils.py 50.00% 1 Missing ⚠️
torchtune/models/llama3_1/_model_builders.py 75.00% 1 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##             main    #1232       +/-   ##
===========================================
+ Coverage   26.96%   70.74%   +43.78%     
===========================================
  Files         288      288               
  Lines       14215    14245       +30     
===========================================
+ Hits         3833    10078     +6245     
+ Misses      10382     4167     -6215     

☔ View full report in Codecov by Sentry.

@ebsmothers ebsmothers merged commit 27d103b into pytorch:main Sep 17, 2024
17 checks passed