QLoRA with Llama 3.1 405B #1232
@@ -0,0 +1,271 @@
# Config for multi-device QLoRA in lora_finetune_fsdp2.py
# using a Llama3.1 405B model
#
# This config requires PyTorch nightlies to run.
# See https://github.com/pytorch/torchtune/blob/main/recipes/dev/fsdp2_recipes.md
# for setup instructions.
#
# This config assumes that you've run the following command before launching
# this run:
#   tune download sllhf/Meta-Llama-3.1-405B-Instruct --hf-token <HF_TOKEN> --ignore-patterns "original/consolidated*"
#
# This config needs 8 GPUs to run:
#   tune run --nproc_per_node 8 lora_finetune_fsdp2 --config llama3.1/405B_qlora
#
# Model Arguments
model:
  _component_: torchtune.models.llama3_1.qlora_llama3_1_405b
  lora_attn_modules: ['q_proj', 'v_proj', 'k_proj', 'output_proj']
  apply_lora_to_mlp: True
  apply_lora_to_output: False
  lora_rank: 16
  lora_alpha: 32

tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: /tmp/Meta-Llama-3.1-405B-Instruct/original/tokenizer.model

checkpointer:
  _component_: torchtune.utils.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Meta-Llama-3.1-405B-Instruct/
  checkpoint_files: [
    model-00001-of-00191.safetensors,
    model-00002-of-00191.safetensors,
    model-00003-of-00191.safetensors,
    model-00004-of-00191.safetensors,
    model-00005-of-00191.safetensors,
    model-00006-of-00191.safetensors,
    model-00007-of-00191.safetensors,
    model-00008-of-00191.safetensors,
    model-00009-of-00191.safetensors,
    model-00010-of-00191.safetensors,
    model-00011-of-00191.safetensors,
    model-00012-of-00191.safetensors,
    model-00013-of-00191.safetensors,
    model-00014-of-00191.safetensors,
    model-00015-of-00191.safetensors,
    model-00016-of-00191.safetensors,
    model-00017-of-00191.safetensors,
    model-00018-of-00191.safetensors,
    model-00019-of-00191.safetensors,
    model-00020-of-00191.safetensors,
    model-00021-of-00191.safetensors,
    model-00022-of-00191.safetensors,
    model-00023-of-00191.safetensors,
    model-00024-of-00191.safetensors,
    model-00025-of-00191.safetensors,
    model-00026-of-00191.safetensors,
    model-00027-of-00191.safetensors,
    model-00028-of-00191.safetensors,
    model-00029-of-00191.safetensors,
    model-00030-of-00191.safetensors,
    model-00031-of-00191.safetensors,
    model-00032-of-00191.safetensors,
    model-00033-of-00191.safetensors,
    model-00034-of-00191.safetensors,
    model-00035-of-00191.safetensors,
    model-00036-of-00191.safetensors,
    model-00037-of-00191.safetensors,
    model-00038-of-00191.safetensors,
    model-00039-of-00191.safetensors,
    model-00040-of-00191.safetensors,
    model-00041-of-00191.safetensors,
    model-00042-of-00191.safetensors,
    model-00043-of-00191.safetensors,
    model-00044-of-00191.safetensors,
    model-00045-of-00191.safetensors,
    model-00046-of-00191.safetensors,
    model-00047-of-00191.safetensors,
    model-00048-of-00191.safetensors,
    model-00049-of-00191.safetensors,
    model-00050-of-00191.safetensors,
    model-00051-of-00191.safetensors,
    model-00052-of-00191.safetensors,
    model-00053-of-00191.safetensors,
    model-00054-of-00191.safetensors,
    model-00055-of-00191.safetensors,
    model-00056-of-00191.safetensors,
    model-00057-of-00191.safetensors,
    model-00058-of-00191.safetensors,
    model-00059-of-00191.safetensors,
    model-00060-of-00191.safetensors,
    model-00061-of-00191.safetensors,
    model-00062-of-00191.safetensors,
    model-00063-of-00191.safetensors,
    model-00064-of-00191.safetensors,
    model-00065-of-00191.safetensors,
    model-00066-of-00191.safetensors,
    model-00067-of-00191.safetensors,
    model-00068-of-00191.safetensors,
    model-00069-of-00191.safetensors,
    model-00070-of-00191.safetensors,
    model-00071-of-00191.safetensors,
    model-00072-of-00191.safetensors,
    model-00073-of-00191.safetensors,
    model-00074-of-00191.safetensors,
    model-00075-of-00191.safetensors,
    model-00076-of-00191.safetensors,
    model-00077-of-00191.safetensors,
    model-00078-of-00191.safetensors,
    model-00079-of-00191.safetensors,
    model-00080-of-00191.safetensors,
    model-00081-of-00191.safetensors,
    model-00082-of-00191.safetensors,
    model-00083-of-00191.safetensors,
    model-00084-of-00191.safetensors,
    model-00085-of-00191.safetensors,
    model-00086-of-00191.safetensors,
    model-00087-of-00191.safetensors,
    model-00088-of-00191.safetensors,
    model-00089-of-00191.safetensors,
    model-00090-of-00191.safetensors,
    model-00091-of-00191.safetensors,
    model-00092-of-00191.safetensors,
    model-00093-of-00191.safetensors,
    model-00094-of-00191.safetensors,
    model-00095-of-00191.safetensors,
    model-00096-of-00191.safetensors,
    model-00097-of-00191.safetensors,
    model-00098-of-00191.safetensors,
    model-00099-of-00191.safetensors,
    model-00100-of-00191.safetensors,
    model-00101-of-00191.safetensors,
    model-00102-of-00191.safetensors,
    model-00103-of-00191.safetensors,
    model-00104-of-00191.safetensors,
    model-00105-of-00191.safetensors,
    model-00106-of-00191.safetensors,
    model-00107-of-00191.safetensors,
    model-00108-of-00191.safetensors,
    model-00109-of-00191.safetensors,
    model-00110-of-00191.safetensors,
    model-00111-of-00191.safetensors,
    model-00112-of-00191.safetensors,
    model-00113-of-00191.safetensors,
    model-00114-of-00191.safetensors,
    model-00115-of-00191.safetensors,
    model-00116-of-00191.safetensors,
    model-00117-of-00191.safetensors,
    model-00118-of-00191.safetensors,
    model-00119-of-00191.safetensors,
    model-00120-of-00191.safetensors,
    model-00121-of-00191.safetensors,
    model-00122-of-00191.safetensors,
    model-00123-of-00191.safetensors,
    model-00124-of-00191.safetensors,
    model-00125-of-00191.safetensors,
    model-00126-of-00191.safetensors,
    model-00127-of-00191.safetensors,
    model-00128-of-00191.safetensors,
    model-00129-of-00191.safetensors,
    model-00130-of-00191.safetensors,
    model-00131-of-00191.safetensors,
    model-00132-of-00191.safetensors,
    model-00133-of-00191.safetensors,
    model-00134-of-00191.safetensors,
    model-00135-of-00191.safetensors,
    model-00136-of-00191.safetensors,
    model-00137-of-00191.safetensors,
    model-00138-of-00191.safetensors,
    model-00139-of-00191.safetensors,
    model-00140-of-00191.safetensors,
    model-00141-of-00191.safetensors,
    model-00142-of-00191.safetensors,
    model-00143-of-00191.safetensors,
    model-00144-of-00191.safetensors,
    model-00145-of-00191.safetensors,
    model-00146-of-00191.safetensors,
    model-00147-of-00191.safetensors,
    model-00148-of-00191.safetensors,
    model-00149-of-00191.safetensors,
    model-00150-of-00191.safetensors,
    model-00151-of-00191.safetensors,
    model-00152-of-00191.safetensors,
    model-00153-of-00191.safetensors,
    model-00154-of-00191.safetensors,
    model-00155-of-00191.safetensors,
    model-00156-of-00191.safetensors,
    model-00157-of-00191.safetensors,
    model-00158-of-00191.safetensors,
    model-00159-of-00191.safetensors,
    model-00160-of-00191.safetensors,
    model-00161-of-00191.safetensors,
    model-00162-of-00191.safetensors,
    model-00163-of-00191.safetensors,
    model-00164-of-00191.safetensors,
    model-00165-of-00191.safetensors,
    model-00166-of-00191.safetensors,
    model-00167-of-00191.safetensors,
    model-00168-of-00191.safetensors,
    model-00169-of-00191.safetensors,
    model-00170-of-00191.safetensors,
    model-00171-of-00191.safetensors,
    model-00172-of-00191.safetensors,
    model-00173-of-00191.safetensors,
    model-00174-of-00191.safetensors,
    model-00175-of-00191.safetensors,
    model-00176-of-00191.safetensors,
    model-00177-of-00191.safetensors,
    model-00178-of-00191.safetensors,
    model-00179-of-00191.safetensors,
    model-00180-of-00191.safetensors,
    model-00181-of-00191.safetensors,
    model-00182-of-00191.safetensors,
    model-00183-of-00191.safetensors,
    model-00184-of-00191.safetensors,
    model-00185-of-00191.safetensors,
    model-00186-of-00191.safetensors,
    model-00187-of-00191.safetensors,
    model-00188-of-00191.safetensors,
    model-00189-of-00191.safetensors,
    model-00190-of-00191.safetensors,
    model-00191-of-00191.safetensors,
  ]
  recipe_checkpoint: null
  output_dir: /tmp/Meta-Llama-3.1-405B-Instruct/
  model_type: LLAMA3
resume_from_checkpoint: False
merge_checkpoint: False

# Dataset and Sampler
dataset:
  _component_: torchtune.datasets.alpaca_dataset
  train_on_input: True
seed: null
shuffle: True
batch_size: 1

# Optimizer and Scheduler
optimizer:
  _component_: torch.optim.AdamW
  weight_decay: 0.01
  lr: 3e-4
  fused: True
lr_scheduler:
  _component_: torchtune.modules.get_cosine_schedule_with_warmup
  num_warmup_steps: 100

loss:
  _component_: torch.nn.CrossEntropyLoss

fsdp:
  cpu_offload: False

# Training
epochs: 1
max_steps_per_epoch: null
gradient_accumulation_steps: 16

# Logging
output_dir: /tmp/qlora_finetune_output
metric_logger:
  _component_: torchtune.utils.metric_logging.DiskLogger
  log_dir: ${output_dir}
log_every_n_steps: 1
log_peak_memory_stats: False

# Environment
device: cuda
dtype: bf16
enable_activation_checkpointing: True
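For context, a recipe materializes the `model:` section above through torchtune's config utilities. A minimal sketch of that flow, assuming the `torchtune.config.instantiate` helper and a local copy of this config saved as `405B_qlora.yaml` (the filename is illustrative, and actually building the 405B model requires the hardware described in the header comments):

```python
# Sketch: how a recipe turns the `model:` section into an nn.Module.
# Assumes torchtune's config.instantiate; the local filename is illustrative.
from omegaconf import OmegaConf

from torchtune import config

cfg = OmegaConf.load("405B_qlora.yaml")
model = config.instantiate(cfg.model)  # dispatches to qlora_llama3_1_405b(...)
```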
@@ -7,10 +7,13 @@
from ._component_builders import llama3_1, lora_llama3_1

[Review comment] linter unhappy :( need to add to …
[Reply] It's in …
from ._model_builders import (  # noqa
    llama3_1_405b,
    llama3_1_70b,
    llama3_1_8b,
    lora_llama3_1_405b,
    lora_llama3_1_70b,
    lora_llama3_1_8b,
    qlora_llama3_1_405b,
    qlora_llama3_1_70b,
    qlora_llama3_1_8b,
)
@@ -19,9 +22,12 @@
    "llama3_1",
    "llama3_1_8b",
    "llama3_1_70b",
    "llama3_1_405b"

[Review comment] missing a comma here, friend
[Reply] :face_palm:

    "lora_llama3_1",
    "lora_llama3_1_8b",
    "lora_llama3_1_70b",
    "lora_llama3_1_405b"

[Review comment] also here

    "qlora_llama3_1_8b",
    "qlora_llama3_1_70b",
    "qlora_llama3_1_405b",
]
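The comma nitpicks above are not cosmetic: Python implicitly concatenates adjacent string literals, so a missing comma silently fuses two `__all__` entries into one bogus name. A minimal illustration:

```python
# Adjacent string literals concatenate, so these two entries fuse into one.
__all__ = [
    "llama3_1_405b"   # <- missing trailing comma
    "lora_llama3_1",  # implicitly joined with the line above
]
assert __all__ == ["llama3_1_405blora_llama3_1"]  # one entry, not two
```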
[Review comment] Guys I think the time has come.. we need to refactor the `checkpoint_files` field to optionally support the combination of (a) number of files and (b) strf filename format (or something like that). This is like 5 minutes of work and will save us from abominations like this. cc @joecummings
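A minimal sketch of what that refactor could look like; the helper name, signature, and zero-padding width are assumptions for illustration, not torchtune's actual API:

```python
# Hypothetical helper: expand (filename format, shard count) into the full
# list, instead of hand-writing 191 checkpoint_files entries in the config.
def formatted_checkpoint_files(
    filename_format: str, max_filename: int, width: int = 5
) -> list[str]:
    # width=5 matches Hugging Face shard naming, e.g. model-00001-of-00191
    total = str(max_filename).zfill(width)
    return [
        filename_format.format(str(i).zfill(width), total)
        for i in range(1, max_filename + 1)
    ]

files = formatted_checkpoint_files("model-{}-of-{}.safetensors", 191)
assert files[0] == "model-00001-of-00191.safetensors"
assert files[-1] == "model-00191-of-00191.safetensors"
```

In a config, that could reduce the 191-line list to two fields, e.g. a `filename_format` string and a `max_filename` count under `checkpoint_files`.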