QLoRA with Llama 3.1 405B #1232

Merged (33 commits), Sep 17, 2024
Changes from 14 commits

Commits
6745533
Update weights and docs
joecummings Jul 22, 2024
86e5389
Merge remote-tracking branch 'upstream/main' into update-weights-4-llama
joecummings Jul 22, 2024
3d61291
First commit
joecummings Jul 23, 2024
8799da4
cleaned config and added to registry
Jul 23, 2024
3a6de7b
Merge branch 'main' into experimental-405b
Jul 23, 2024
5d8e063
Updated builders to 3_1
Jul 23, 2024
5760fd5
Updated docs
Jul 23, 2024
3bc86b5
Merge branch 'pytorch:main' into experimental-405b
pbontrager Jul 24, 2024
35f24ac
Only save adapter every epoch
Jul 25, 2024
0566636
Merge branch 'main' into experimental-405b
Jul 25, 2024
9b8719c
checkpointer not require model params
pbontrager Jul 25, 2024
f615bc4
made merge_checkpoint optional
pbontrager Jul 26, 2024
490f18a
Merge branch 'main' into experimental-405b
pbontrager Jul 26, 2024
4d104ce
Fixed based on comments and tests
pbontrager Jul 26, 2024
319f476
missing commas
pbontrager Jul 26, 2024
3055806
Merge remote-tracking branch 'upstream/main' into experimental-405b
joecummings Jul 29, 2024
a79f01a
chore: lint
joecummings Jul 29, 2024
7a97cd9
Merge branch 'experimental-405b' of github.com:joecummings/torchtune …
joecummings Jul 29, 2024
3148418
Add compile support
joecummings Jul 30, 2024
542c968
asdf
joecummings Jul 30, 2024
1f7a78a
Lint
joecummings Aug 5, 2024
81a0eb4
Changes
joecummings Aug 6, 2024
610702a
merge with main
ebsmothers Sep 14, 2024
d6bbc5e
wip changes for formatted checkpoint files
ebsmothers Sep 14, 2024
a930f88
wip changes
ebsmothers Sep 15, 2024
977efb1
training starts
ebsmothers Sep 16, 2024
cd17e6b
some cleanup
ebsmothers Sep 16, 2024
f147667
checkpoint save and cleanup
ebsmothers Sep 16, 2024
279fde2
test and docstring
ebsmothers Sep 16, 2024
26bf19a
expose in docs, fix config field name
ebsmothers Sep 16, 2024
ce5433b
merge
ebsmothers Sep 16, 2024
a2a7afe
doc update
ebsmothers Sep 16, 2024
8b1a33d
merge
ebsmothers Sep 16, 2024
7 changes: 5 additions & 2 deletions docs/source/api_ref_models.rst
@@ -25,8 +25,8 @@ To download the Llama3-70B-Instruct model:

tune download meta-llama/Meta-Llama-3-70B-Instruct --hf-token <HF_TOKEN> --ignore-patterns "original/consolidated*"

To download the Llama3.1 weights of the above models, you can instead download from `Meta-Llama-3.1-8B-Instruct`
or `Meta-Llama-3.1-70B-Instruct`.
To download the Llama3.1 weights of the above models, you can instead download from `Meta-Llama-3.1-8B-Instruct`,
`Meta-Llama-3.1-70B-Instruct`, or `Meta-Llama-3.1-405B-Instruct`.

.. autosummary::
:toctree: generated/
@@ -53,6 +53,9 @@ or `Meta-Llama-3.1-70B-Instruct`.
llama3_1.llama3_1_70b
llama3_1.lora_llama3_1_70b
llama3_1.qlora_llama3_1_70b
llama3_1.llama3_1_405b
llama3_1.lora_llama3_1_405b
llama3_1.qlora_llama3_1_405b


.. note::
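For reference, the 405B builders added to the autosummary above are invoked the same way as the existing 8B and 70B ones. A minimal usage sketch, with keyword arguments taken from the model section of the QLoRA config added later in this PR (illustrative only; materializing the full 405B model requires far more memory than a single process typically has):

from torchtune.models.llama3_1 import qlora_llama3_1_405b

# Keyword arguments mirror the model section of 405B_qlora_fsdp2.yaml.
model = qlora_llama3_1_405b(
    lora_attn_modules=["q_proj", "v_proj", "k_proj", "output_proj"],
    apply_lora_to_mlp=True,
    apply_lora_to_output=False,
    lora_rank=16,
    lora_alpha=32,
)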
271 changes: 271 additions & 0 deletions recipes/configs/dev/llama3_1/405B_qlora_fsdp2.yaml
@@ -0,0 +1,271 @@
# Config for multi-device QLoRA in lora_finetune_fsdp2.py
# using a Llama3.1 405B model
#
# This config requires PyTorch nightlies to run.
# See https://github.com/pytorch/torchtune/blob/main/recipes/dev/fsdp2_recipes.md
# for setup instructions.
#
# This config assumes that you've run the following command before launching
# this run:
# tune download sllhf/Meta-Llama-3.1-405B-Instruct --hf-token <HF_TOKEN> --ignore-patterns "original/consolidated*"
#
# This config needs 8 GPUs to run:
#   tune run --nproc_per_node 8 lora_finetune_fsdp2 --config llama3_1/405B_qlora
#

# Model Arguments
model:
_component_: torchtune.models.llama3_1.qlora_llama3_1_405b
lora_attn_modules: ['q_proj', 'v_proj', 'k_proj', 'output_proj']
apply_lora_to_mlp: True
apply_lora_to_output: False
lora_rank: 16
lora_alpha: 32

tokenizer:
_component_: torchtune.models.llama3.llama3_tokenizer
path: /tmp/Meta-Llama-3.1-405B-Instruct/original/tokenizer.model

checkpointer:
_component_: torchtune.utils.FullModelHFCheckpointer
checkpoint_dir: /tmp/Meta-Llama-3.1-405B-Instruct/
checkpoint_files: [
Review comment from @ebsmothers (Contributor), Jul 29, 2024:
Guys I think the time has come.. we need to refactor checkpoint_files field to optionally support the combination of (a) number of files and (b) strf filename format (or something like that). This is like 5 minutes of work and will save us from abominations like this. cc @joecummings
(A sketch of this idea follows the config below.)

model-00001-of-00191.safetensors,
model-00002-of-00191.safetensors,
model-00003-of-00191.safetensors,
model-00004-of-00191.safetensors,
model-00005-of-00191.safetensors,
model-00006-of-00191.safetensors,
model-00007-of-00191.safetensors,
model-00008-of-00191.safetensors,
model-00009-of-00191.safetensors,
model-00010-of-00191.safetensors,
model-00011-of-00191.safetensors,
model-00012-of-00191.safetensors,
model-00013-of-00191.safetensors,
model-00014-of-00191.safetensors,
model-00015-of-00191.safetensors,
model-00016-of-00191.safetensors,
model-00017-of-00191.safetensors,
model-00018-of-00191.safetensors,
model-00019-of-00191.safetensors,
model-00020-of-00191.safetensors,
model-00021-of-00191.safetensors,
model-00022-of-00191.safetensors,
model-00023-of-00191.safetensors,
model-00024-of-00191.safetensors,
model-00025-of-00191.safetensors,
model-00026-of-00191.safetensors,
model-00027-of-00191.safetensors,
model-00028-of-00191.safetensors,
model-00029-of-00191.safetensors,
model-00030-of-00191.safetensors,
model-00031-of-00191.safetensors,
model-00032-of-00191.safetensors,
model-00033-of-00191.safetensors,
model-00034-of-00191.safetensors,
model-00035-of-00191.safetensors,
model-00036-of-00191.safetensors,
model-00037-of-00191.safetensors,
model-00038-of-00191.safetensors,
model-00039-of-00191.safetensors,
model-00040-of-00191.safetensors,
model-00041-of-00191.safetensors,
model-00042-of-00191.safetensors,
model-00043-of-00191.safetensors,
model-00044-of-00191.safetensors,
model-00045-of-00191.safetensors,
model-00046-of-00191.safetensors,
model-00047-of-00191.safetensors,
model-00048-of-00191.safetensors,
model-00049-of-00191.safetensors,
model-00050-of-00191.safetensors,
model-00051-of-00191.safetensors,
model-00052-of-00191.safetensors,
model-00053-of-00191.safetensors,
model-00054-of-00191.safetensors,
model-00055-of-00191.safetensors,
model-00056-of-00191.safetensors,
model-00057-of-00191.safetensors,
model-00058-of-00191.safetensors,
model-00059-of-00191.safetensors,
model-00060-of-00191.safetensors,
model-00061-of-00191.safetensors,
model-00062-of-00191.safetensors,
model-00063-of-00191.safetensors,
model-00064-of-00191.safetensors,
model-00065-of-00191.safetensors,
model-00066-of-00191.safetensors,
model-00067-of-00191.safetensors,
model-00068-of-00191.safetensors,
model-00069-of-00191.safetensors,
model-00070-of-00191.safetensors,
model-00071-of-00191.safetensors,
model-00072-of-00191.safetensors,
model-00073-of-00191.safetensors,
model-00074-of-00191.safetensors,
model-00075-of-00191.safetensors,
model-00076-of-00191.safetensors,
model-00077-of-00191.safetensors,
model-00078-of-00191.safetensors,
model-00079-of-00191.safetensors,
model-00080-of-00191.safetensors,
model-00081-of-00191.safetensors,
model-00082-of-00191.safetensors,
model-00083-of-00191.safetensors,
model-00084-of-00191.safetensors,
model-00085-of-00191.safetensors,
model-00086-of-00191.safetensors,
model-00087-of-00191.safetensors,
model-00088-of-00191.safetensors,
model-00089-of-00191.safetensors,
model-00090-of-00191.safetensors,
model-00091-of-00191.safetensors,
model-00092-of-00191.safetensors,
model-00093-of-00191.safetensors,
model-00094-of-00191.safetensors,
model-00095-of-00191.safetensors,
model-00096-of-00191.safetensors,
model-00097-of-00191.safetensors,
model-00098-of-00191.safetensors,
model-00099-of-00191.safetensors,
model-00100-of-00191.safetensors,
model-00101-of-00191.safetensors,
model-00102-of-00191.safetensors,
model-00103-of-00191.safetensors,
model-00104-of-00191.safetensors,
model-00105-of-00191.safetensors,
model-00106-of-00191.safetensors,
model-00107-of-00191.safetensors,
model-00108-of-00191.safetensors,
model-00109-of-00191.safetensors,
model-00110-of-00191.safetensors,
model-00111-of-00191.safetensors,
model-00112-of-00191.safetensors,
model-00113-of-00191.safetensors,
model-00114-of-00191.safetensors,
model-00115-of-00191.safetensors,
model-00116-of-00191.safetensors,
model-00117-of-00191.safetensors,
model-00118-of-00191.safetensors,
model-00119-of-00191.safetensors,
model-00120-of-00191.safetensors,
model-00121-of-00191.safetensors,
model-00122-of-00191.safetensors,
model-00123-of-00191.safetensors,
model-00124-of-00191.safetensors,
model-00125-of-00191.safetensors,
model-00126-of-00191.safetensors,
model-00127-of-00191.safetensors,
model-00128-of-00191.safetensors,
model-00129-of-00191.safetensors,
model-00130-of-00191.safetensors,
model-00131-of-00191.safetensors,
model-00132-of-00191.safetensors,
model-00133-of-00191.safetensors,
model-00134-of-00191.safetensors,
model-00135-of-00191.safetensors,
model-00136-of-00191.safetensors,
model-00137-of-00191.safetensors,
model-00138-of-00191.safetensors,
model-00139-of-00191.safetensors,
model-00140-of-00191.safetensors,
model-00141-of-00191.safetensors,
model-00142-of-00191.safetensors,
model-00143-of-00191.safetensors,
model-00144-of-00191.safetensors,
model-00145-of-00191.safetensors,
model-00146-of-00191.safetensors,
model-00147-of-00191.safetensors,
model-00148-of-00191.safetensors,
model-00149-of-00191.safetensors,
model-00150-of-00191.safetensors,
model-00151-of-00191.safetensors,
model-00152-of-00191.safetensors,
model-00153-of-00191.safetensors,
model-00154-of-00191.safetensors,
model-00155-of-00191.safetensors,
model-00156-of-00191.safetensors,
model-00157-of-00191.safetensors,
model-00158-of-00191.safetensors,
model-00159-of-00191.safetensors,
model-00160-of-00191.safetensors,
model-00161-of-00191.safetensors,
model-00162-of-00191.safetensors,
model-00163-of-00191.safetensors,
model-00164-of-00191.safetensors,
model-00165-of-00191.safetensors,
model-00166-of-00191.safetensors,
model-00167-of-00191.safetensors,
model-00168-of-00191.safetensors,
model-00169-of-00191.safetensors,
model-00170-of-00191.safetensors,
model-00171-of-00191.safetensors,
model-00172-of-00191.safetensors,
model-00173-of-00191.safetensors,
model-00174-of-00191.safetensors,
model-00175-of-00191.safetensors,
model-00176-of-00191.safetensors,
model-00177-of-00191.safetensors,
model-00178-of-00191.safetensors,
model-00179-of-00191.safetensors,
model-00180-of-00191.safetensors,
model-00181-of-00191.safetensors,
model-00182-of-00191.safetensors,
model-00183-of-00191.safetensors,
model-00184-of-00191.safetensors,
model-00185-of-00191.safetensors,
model-00186-of-00191.safetensors,
model-00187-of-00191.safetensors,
model-00188-of-00191.safetensors,
model-00189-of-00191.safetensors,
model-00190-of-00191.safetensors,
model-00191-of-00191.safetensors,
]
recipe_checkpoint: null
output_dir: /tmp/Meta-Llama-3.1-405B-Instruct/
model_type: LLAMA3
resume_from_checkpoint: False
merge_checkpoint: False

# Dataset and Sampler
dataset:
_component_: torchtune.datasets.alpaca_dataset
train_on_input: True
seed: null
shuffle: True
batch_size: 1

# Optimizer and Scheduler
optimizer:
_component_: torch.optim.AdamW
weight_decay: 0.01
lr: 3e-4
fused: True
lr_scheduler:
_component_: torchtune.modules.get_cosine_schedule_with_warmup
num_warmup_steps: 100

loss:
_component_: torch.nn.CrossEntropyLoss

fsdp:
cpu_offload: False

# Training
epochs: 1
max_steps_per_epoch: null
gradient_accumulation_steps: 16

# Logging
output_dir: /tmp/qlora_finetune_output
metric_logger:
_component_: torchtune.utils.metric_logging.DiskLogger
log_dir: ${output_dir}
log_every_n_steps: 1
log_peak_memory_stats: False

# Environment
device: cuda
dtype: bf16
enable_activation_checkpointing: True
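As a follow-up to the review comment on checkpoint_files above, here is a minimal sketch of the proposed refactor: derive the file list from a filename format string plus a file count instead of enumerating all 191 entries by hand. The helper name and signature below are hypothetical, not an existing torchtune API.

from typing import List

def format_checkpoint_files(filename_format: str, n_files: int) -> List[str]:
    # "model-{:05d}-of-{:05d}.safetensors" with n_files=191 expands to the
    # same 191-entry list written out by hand in the config above.
    return [filename_format.format(i, n_files) for i in range(1, n_files + 1)]

files = format_checkpoint_files("model-{:05d}-of-{:05d}.safetensors", 191)
assert files[0] == "model-00001-of-00191.safetensors"
assert files[-1] == "model-00191-of-00191.safetensors"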
24 changes: 16 additions & 8 deletions recipes/dev/lora_finetune_fsdp2.py
@@ -139,6 +139,7 @@ def __init__(self, cfg: DictConfig) -> None:
self.max_steps_per_epoch = cfg.max_steps_per_epoch
self.global_step = 0

self._merge_checkpoint = cfg.get("merge_checkpoint", True)
self._resume_from_checkpoint = cfg.resume_from_checkpoint
self._gradient_accumulation_steps = cfg.gradient_accumulation_steps

@@ -463,15 +464,20 @@ def save_checkpoint(
different checkpoint files. To correctly resume from training, the adapter weights
and recipe state must be provided along with the base model weights.
"""
# Only build adapter
# If not adapter only, load weights

# final dict passed onto the checkpointer
checkpoint_dict = {}

intermediate_checkpoint = epoch + 1 < self.total_epochs
merge_checkpoint = self._merge_checkpoint and not intermediate_checkpoint
# To prevent GPU memory from spiking during checkpoint save,
# we consolidate the full model and optim state dicts on CPU for rank 0
cpu_state_dict = utils.get_full_model_state_dict(
self._model,
self._is_rank_zero,
trainable_only=not merge_checkpoint,
)

if intermediate_checkpoint:
@@ -494,14 +500,6 @@
}
checkpoint_dict.update({utils.ADAPTER_KEY: adapter_state_dict})

# merge the adapter weights and base weights to create the model checkpoint
merged_state_dict = get_merged_lora_ckpt(
cpu_state_dict,
rank=self._lora_rank,
alpha=self._lora_alpha,
)
checkpoint_dict.update({utils.MODEL_KEY: merged_state_dict})

# if training is in-progress, checkpoint the optimizer state and recipe state
# as well.
if intermediate_checkpoint:
@@ -515,6 +513,15 @@
}
)

# if training is complete, optionally merge adapter with the base weights
if merge_checkpoint:
merged_state_dict = get_merged_lora_ckpt(
cpu_state_dict,
rank=self._lora_rank,
alpha=self._lora_alpha,
)
checkpoint_dict.update({utils.MODEL_KEY: merged_state_dict})

adapter_config = {
"r": self._lora_rank,
"lora_alpha": self._lora_alpha,
@@ -531,6 +538,7 @@
checkpoint_dict,
epoch=epoch,
intermediate_checkpoint=intermediate_checkpoint,
adapter_only=not merge_checkpoint,
)

def train(self) -> None:
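For context on the merge_checkpoint change above: when the flag is set and training is complete, get_merged_lora_ckpt folds the LoRA adapter weights into the base weights so the saved checkpoint can be loaded without LoRA modules attached. A per-layer sketch of that merge under the standard LoRA parameterization (the real utility walks a full state dict; the shapes below are made up for illustration, and the rank/alpha values match the config):

import torch

def merge_lora_weight(base_w, lora_a, lora_b, rank, alpha):
    # Standard LoRA merge: W' = W + (alpha / rank) * (B @ A)
    return base_w + (alpha / rank) * (lora_b @ lora_a)

# Hypothetical 64 -> 32 linear layer with rank-16 adapters (lora_rank=16, lora_alpha=32).
base_w = torch.randn(32, 64)
lora_a = torch.randn(16, 64)   # A: (rank, in_dim)
lora_b = torch.zeros(32, 16)   # B: (out_dim, rank), zero-initialized
merged = merge_lora_weight(base_w, lora_a, lora_b, rank=16, alpha=32)
assert torch.allclose(merged, base_w)  # B is zero-init, so the merge is a no-op here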
1 change: 1 addition & 0 deletions torchtune/_recipe_registry.py
@@ -201,6 +201,7 @@ class Recipe:
name="llama2/70B_qlora",
file_path="dev/llama2/70B_qlora_fsdp2.yaml",
),
Config(name="llama3_1/405B_qlora", file_path="dev/llama3_1/405B_qlora_fsdp2.yaml"),
],
supports_distributed=True,
),
6 changes: 6 additions & 0 deletions torchtune/models/llama3_1/__init__.py
@@ -7,10 +7,13 @@
from ._component_builders import llama3_1, lora_llama3_1
Review comment from a Collaborator:
linter unhappy :( need to add to __all__?

Reply from @pbontrager (Contributor, Author), Jul 26, 2024:
It's in __all__, no idea what's triggering this. The pre-commit flake8 passes too.

from ._model_builders import ( # noqa
llama3_1_405b,
llama3_1_70b,
llama3_1_8b,
lora_llama3_1_405b,
lora_llama3_1_70b,
lora_llama3_1_8b,
qlora_llama3_1_405b,
qlora_llama3_1_70b,
qlora_llama3_1_8b,
)
@@ -19,9 +22,12 @@
"llama3_1",
"llama3_1_8b",
"llama3_1_70b",
"llama3_1_405b"
Review comment from a Collaborator:
missing a comma here, friend
(A short note on why the missing comma matters follows this file's diff.)

Reply from the Contributor Author:
:face_palm:
"lora_llama3_1",
"lora_llama3_1_8b",
"lora_llama3_1_70b",
"lora_llama3_1_405b"
Review comment from a Collaborator:
also here
"qlora_llama3_1_8b",
"qlora_llama3_1_70b",
"qlora_llama3_1_405b",
]
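A short note on the missing-comma comments above: Python concatenates adjacent string literals, so the omission does not raise an error; it silently fuses two entries of __all__ into one bogus name. A standalone illustration (not torchtune code):

exports = [
    "llama3_1_70b",
    "llama3_1_405b"   # missing comma here...
    "lora_llama3_1",  # ...so these two literals fuse into a single string
]
print(exports)
# ['llama3_1_70b', 'llama3_1_405blora_llama3_1']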