QLoRA with Llama 3.1 405B #1232
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1232
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures
As of commit 8b1a33d with merge base 7045e96. This comment was automatically generated by Dr. CI and updates every 15 minutes.
@@ -63,6 +63,27 @@ def llama3_1_70b() -> TransformerDecoder:
    )


def llama3_1_405b() -> TransformerDecoder:
    """
    Builder for creating a Llama3 model initialized w/ the default 405B parameter values.
nit: Llama3.1 here and in the return
torchtune/utils/_distributed.py (Outdated)
Args:
    model (FSDPModule): wrapped module
    is_rank_zero (bool): flag to check if the process is on rank 0
    trainable_only (bool): flag to only return state dict of trainable parameters
How will this work with #1220?
Initially I thought just saving the adapter would be enough, but for the 405B model, consolidating the state_dict is also prohibitively expensive. This update can be added to the other recipes after this lands.
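For concreteness, a minimal sketch of gathering only the trainable (adapter) parameters from an FSDP2-sharded model; the helper name and details are assumptions for illustration, not the code in this PR:

import torch
from torch.distributed._tensor import DTensor


def gather_trainable_state_dict(model: torch.nn.Module, is_rank_zero: bool) -> dict:
    # Hypothetical helper: only all-gather parameters that require grad
    # (i.e. the LoRA adapter weights), skipping the frozen base weights.
    cpu_state_dict = {}
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        # FSDP2 shards parameters as DTensors; full_tensor() all-gathers the
        # shards on every rank, so it is called before the rank-zero check.
        full_param = param.full_tensor() if isinstance(param, DTensor) else param
        if is_rank_zero:
            cpu_state_dict[name] = full_param.cpu()
    return cpu_state_dict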
No NCCL timeouts?
Couple of small comments (looks like CI is failing too?). This is huge. Two points as an aside:
The only reason this is dev is that FSDP2 support just dropped in PyTorch stable a few days ago, so we want to play around with it more before moving to stable. We need FSDP2 support b/c it plays nicely with QLoRA.
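For illustration, a minimal sketch of sharding a QLoRA model with FSDP2 (assumptions: model.layers holds the transformer blocks, base weights are NF4-quantized, LoRA params are bf16; this is not the recipe's actual wrapping code):

import torch
from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard


def shard_qlora_model(model: torch.nn.Module) -> torch.nn.Module:
    # FSDP2 (fully_shard) shards individual parameters as DTensors instead of
    # flattening them, which is what lets NF4-quantized base weights coexist
    # with regular bf16 LoRA weights in the same module.
    mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)
    for layer in model.layers:
        fully_shard(layer, mp_policy=mp_policy)
    fully_shard(model, mp_policy=mp_policy)
    return model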
We'd love this - I can definitely promote on Discord and Twitter and start on a formal blog writeup, but would appreciate any help here :)
I'd love to! Happy to help out here in any capacity - reviewing or helping flesh out a rough sketch.
@@ -7,10 +7,13 @@
from ._component_builders import llama3_1, lora_llama3_1
linter unhappy :( need to add to __all__?
It's in __all__, no idea what's triggering this. The pre-commit flake8 passes too.
@@ -19,9 +22,12 @@
    "llama3_1",
    "llama3_1_8b",
    "llama3_1_70b",
    "llama3_1_405b"
missing a comma here, friend
:face_palm:
"lora_llama3_1", | ||
"lora_llama3_1_8b", | ||
"lora_llama3_1_70b", | ||
"lora_llama3_1_405b" |
also here
checkpointer:
  _component_: torchtune.utils.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Meta-Llama-3.1-405B-Instruct/
  checkpoint_files: [
Guys I think the time has come.. we need to refactor the checkpoint_files field to optionally support the combination of (a) number of files and (b) strf filename format (or something like that). This is like 5 minutes of work and will save us from abominations like this. cc @joecummings
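Something like this could do it (purely a sketch of the idea with assumed names and format, not the class that eventually landed): expand a filename template plus a shard count into the full list, so configs don't need 191 literal entries.

from dataclasses import dataclass
from typing import List


@dataclass
class FormattedCheckpointFilesSketch:
    # e.g. filename_format="model-{:05d}-of-{:05d}.safetensors", num_files=191
    filename_format: str
    num_files: int

    def build_checkpoint_filenames(self) -> List[str]:
        # Expand the template into the full list of shard filenames.
        return [
            self.filename_format.format(i, self.num_files)
            for i in range(1, self.num_files + 1)
        ]


# FormattedCheckpointFilesSketch("model-{:05d}-of-{:05d}.safetensors", 191).build_checkpoint_filenames()
# -> ["model-00001-of-00191.safetensors", ..., "model-00191-of-00191.safetensors"]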
pydoclint................................................................Failed
- hook id: pydoclint
- exit code: 1
Skipping files that match this pattern: tests/torchtune/models/(\w+)/scripts/
recipes/dev/lora_finetune_fsdp2.py
torchtune/_recipe_registry.py
torchtune/models/llama3_1/__init__.py
torchtune/utils/_checkpointing/_checkpointer.py
torchtune/utils/_distributed.py
torchtune/utils/_distributed.py
361: DOC201: Function `get_full_model_state_dict` does not have a return section in docstring
361: DOC501: Function `get_full_model_state_dict` has "raise" statements, but the docstring does not have a "Raises" section
Loading config from user-specified .toml file: pyproject.toml
Found options defined in pyproject.toml:
{'style': 'google', 'check_return_types': 'False', 'exclude': 'tests/torchtune/models/(\\w+)/scripts/'}

welcome back to hell, brother
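For reference, this is the general shape of a Google-style docstring that satisfies DOC201 and DOC501; the Returns/Raises wording here is illustrative, not the docstring that was actually added:

def get_full_model_state_dict(model, is_rank_zero):
    """Gather a full, unsharded model state dict.

    Args:
        model (FSDPModule): wrapped module
        is_rank_zero (bool): flag to check if the process is on rank 0

    Returns:
        Dict[str, Any]: full state dict on rank 0, an empty dict on other ranks.

    Raises:
        RuntimeError: if a gathered parameter is still a sharded DTensor.
    """
    ...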
Currently hitting this error with compile, no speedups:
This happens on all ranks.
@joecummings it's a bit hard to tell exactly what we're recompiling on from the logs (since it just includes a type id). But one known cause of recompilations using DTensor today is that we overspecialize on all of the DTensor metadata. This PR fixes it, so once it lands can you try again? (And we can stare at the recompiles some more if that doesn't fix it.) pytorch/pytorch#124401
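As a side note (standard torch.compile tooling, not something specific to this PR or to torchtune), turning on recompile logging prints the failed guard and the reason for each recompilation, which gives more signal than a bare type id:

import torch

# Print the guard that failed and the reason for each recompilation.
# Equivalent to running with the TORCH_LOGS="recompiles" environment variable.
torch._logging.set_logs(recompiles=True)

compiled_fn = torch.compile(lambda x: x * 2)
compiled_fn(torch.randn(4))
compiled_fn(torch.randn(8))  # a shape change like this may show up as a logged recompile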
Why not allow a regex pattern for files?
test???????!?!??!?!?!?!
Also what the hell does strf_name stand for?
string formatted name. Open to better names though
test???????!?!??!?!?!?!
ok ok ok. The FormattedCheckpointFiles class was hacked together pretty quickly, first just wanted to make sure it made sense to people.
Codecov Report
Attention: Patch coverage is

Additional details and impacted files
@@            Coverage Diff            @@
##             main    #1232       +/-  ##
===========================================
+ Coverage   26.96%   70.74%   +43.78%
===========================================
  Files         288      288
  Lines       14215    14245       +30
===========================================
+ Hits         3833    10078     +6245
+ Misses      10382     4167     -6215

☔ View full report in Codecov by Sentry.
Edit: this PR has been resuscitated. The main change since the previous version is adding a new way to pass checkpoint files. Now we can just do
without having to copy-paste 191 rows of config. We should integrate this into all our 70B configs too when we get a chance.
Test plan
Added some new tests for the FormattedCheckpointFiles class.
E2E test:
Setting aside the fact that our logs are annoying here, this works.
Context
What is the purpose of this PR? Is it to
This PR adds a Llama 3.1 405B QLoRA config for the lora_finetune_fsdp2 recipe. This config requires 8 x A100s to run. The config successfully finetunes Llama 405B on the alpaca dataset as an example. There are several caveats to getting this large model to fit on a single node:
This PR was made jointly by me and @joecummings.
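For rough intuition on why a single 8 x A100 node is the floor here, a back-of-envelope estimate (assuming NF4-quantized frozen base weights fully sharded across 8 GPUs; it ignores activations, LoRA params, optimizer state, and quantization metadata, and is not a measurement from this run):

# Back-of-envelope memory estimate; see assumptions in the note above.
num_params = 405e9
bytes_per_param_nf4 = 0.5                                  # 4 bits per parameter
total_weight_gb = num_params * bytes_per_param_nf4 / 1e9   # ~202.5 GB of quantized base weights
per_gpu_gb = total_weight_gb / 8                           # ~25 GB per 80 GB A100
print(f"~{total_weight_gb:.0f} GB total, ~{per_gpu_gb:.1f} GB per GPU for base weights alone")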
Changelog
Test plan
Please make sure to do each of the following if applicable to your PR. (If you're not sure about any one of these just ask and we will happily help.)
pre-commit install
pytest tests
pytest tests -m integration_test
Test Memory + Checkpointing
tune run --nproc_per_node 8 lora_finetune_fsdp2 --config llama3_1/405B_qlora log_peak_memory_stats=True dataset.packed=True dataset.max_seq_len=2048 max_steps_per_epoch=2 epochs=2
Test Model Loss
tune run --nproc_per_node 8 lora_finetune_fsdp2 --config llama3_1/405B_qlora