Added adapter_only option to LoRA #1220

Merged
merged 13 commits on Jul 29, 2024
5 changes: 4 additions & 1 deletion docs/source/deep_dives/checkpointer.rst
@@ -361,7 +361,7 @@ Checkpointing for LoRA
In torchtune, we output both the adapter weights and the full model "merged" weights
for LoRA. The "merged" checkpoint can be used just like you would use the source
checkpoint with any post-training tools. For more details, take a look at our
:ref:`LoRA Finetuning Tutorial <lora_finetune_label>`.
:ref:`LoRA Finetuning Tutorial <lora_finetune_label>`. Additionally, by setting ``save_adapter_weights_only`` to True when saving a checkpoint, you can choose to save only the adapter weights.

The primary difference between the two use cases is when you want to resume training
from a checkpoint. In this case, the checkpointer needs access to both the initial frozen
@@ -404,6 +404,9 @@ looks something like this:
# set to True if restarting training
resume_from_checkpoint: True

# Set to True to save only the adapter weights
save_adapter_weights_only: False
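For illustration, an adapter-only checkpoint produced this way can later be re-attached to a freshly built LoRA model. The sketch below assumes a Llama2 7B LoRA setup and an illustrative checkpoint path; neither is taken from this PR.

import torch
from torchtune.models.llama2 import lora_llama2_7b

# Build the LoRA model; in a real workflow the frozen base weights would be
# loaded from the original checkpoint first.
model = lora_llama2_7b(lora_attn_modules=["q_proj", "v_proj"])

# Hypothetical path to a checkpoint saved with save_adapter_weights_only: True.
adapter_state_dict = torch.load("/tmp/Llama-2-7b-hf/adapter_0.pt", map_location="cpu")

# The file holds only the LoRA adapter weights, so strict=False tolerates the
# missing base-model keys; no key in the file should be unexpected.
missing, unexpected = model.load_state_dict(adapter_state_dict, strict=False)
assert not unexpected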

|

Putting this all together
1 change: 1 addition & 0 deletions recipes/configs/code_llama2/7B_lora_single_device.yaml
@@ -49,6 +49,7 @@ checkpointer:
output_dir: /tmp/CodeLlama-7b-hf
model_type: LLAMA2
resume_from_checkpoint: False
save_adapter_weights_only: False

# Fine-tuning arguments
batch_size: 2
1 change: 1 addition & 0 deletions recipes/configs/code_llama2/7B_qlora_single_device.yaml
@@ -49,6 +49,7 @@ checkpointer:
output_dir: /tmp/CodeLlama-7b-hf
model_type: LLAMA2
resume_from_checkpoint: False
save_adapter_weights_only: False

# Fine-tuning arguments and training
batch_size: 2
1 change: 1 addition & 0 deletions recipes/configs/dev/llama2/13B_lora_fsdp2.yaml
@@ -45,6 +45,7 @@ checkpointer:
output_dir: /tmp/Llama-2-13b-hf/
model_type: LLAMA2
resume_from_checkpoint: False
save_adapter_weights_only: False

# Tokenizer
tokenizer:
1 change: 1 addition & 0 deletions recipes/configs/dev/llama2/70B_lora_fsdp2.yaml
@@ -50,6 +50,7 @@ checkpointer:
output_dir: /tmp/Llama-2-70b-hf
model_type: LLAMA2
resume_from_checkpoint: False
save_adapter_weights_only: False

# Dataset and Sampler
dataset:
1 change: 1 addition & 0 deletions recipes/configs/dev/llama2/70B_qlora_fsdp2.yaml
@@ -50,6 +50,7 @@ checkpointer:
output_dir: /tmp/Llama-2-70b-hf
model_type: LLAMA2
resume_from_checkpoint: False
save_adapter_weights_only: False

# Dataset and Sampler
dataset:
1 change: 1 addition & 0 deletions recipes/configs/dev/llama2/7B_lora_fsdp2.yaml
@@ -47,6 +47,7 @@ checkpointer:
output_dir: /tmp/Llama-2-7b-hf
model_type: LLAMA2
resume_from_checkpoint: False
save_adapter_weights_only: False

# Dataset and Sampler
dataset:
1 change: 1 addition & 0 deletions recipes/configs/dev/llama2/7B_qlora_fsdp2.yaml
@@ -46,6 +46,7 @@ checkpointer:
output_dir: /tmp/Llama-2-7b-hf
model_type: LLAMA2
resume_from_checkpoint: False
save_adapter_weights_only: False

# Dataset and Sampler
dataset:
1 change: 1 addition & 0 deletions recipes/configs/gemma/2B_lora.yaml
@@ -46,6 +46,7 @@ checkpointer:
output_dir: /tmp/gemma-2b
model_type: GEMMA
resume_from_checkpoint: False
save_adapter_weights_only: False

optimizer:
_component_: torch.optim.AdamW
1 change: 1 addition & 0 deletions recipes/configs/gemma/2B_lora_single_device.yaml
@@ -45,6 +45,7 @@ checkpointer:
output_dir: /tmp/gemma-2b
model_type: GEMMA
resume_from_checkpoint: False
save_adapter_weights_only: False

optimizer:
_component_: torch.optim.AdamW
1 change: 1 addition & 0 deletions recipes/configs/gemma/2B_qlora_single_device.yaml
@@ -45,6 +45,7 @@ checkpointer:
output_dir: /tmp/gemma-2b
model_type: GEMMA
resume_from_checkpoint: False
save_adapter_weights_only: False

optimizer:
_component_: torch.optim.AdamW
1 change: 1 addition & 0 deletions recipes/configs/gemma/7B_lora.yaml
@@ -48,6 +48,7 @@ checkpointer:
output_dir: /tmp/gemma-7b/
model_type: GEMMA
resume_from_checkpoint: False
save_adapter_weights_only: False

optimizer:
_component_: torch.optim.AdamW
1 change: 1 addition & 0 deletions recipes/configs/gemma/7B_lora_single_device.yaml
@@ -47,6 +47,7 @@ checkpointer:
output_dir: /tmp/gemma-7b/
model_type: GEMMA
resume_from_checkpoint: False
save_adapter_weights_only: False

optimizer:
_component_: torch.optim.AdamW
1 change: 1 addition & 0 deletions recipes/configs/gemma/7B_qlora_single_device.yaml
@@ -47,6 +47,7 @@ checkpointer:
output_dir: /tmp/gemma-7b/
model_type: GEMMA
resume_from_checkpoint: False
save_adapter_weights_only: False

optimizer:
_component_: torch.optim.AdamW
1 change: 1 addition & 0 deletions recipes/configs/llama2/13B_lora.yaml
@@ -41,6 +41,7 @@ checkpointer:
output_dir: /tmp/Llama-2-13b-hf/
model_type: LLAMA2
resume_from_checkpoint: False
save_adapter_weights_only: False

# Tokenizer
tokenizer:
1 change: 1 addition & 0 deletions recipes/configs/llama2/13B_qlora_single_device.yaml
@@ -41,6 +41,7 @@ checkpointer:
output_dir: /tmp/Llama-2-13b-hf/
model_type: LLAMA2
resume_from_checkpoint: False
save_adapter_weights_only: False

# Dataset and Sampler
dataset:
1 change: 1 addition & 0 deletions recipes/configs/llama2/70B_lora.yaml
@@ -46,6 +46,7 @@ checkpointer:
output_dir: /tmp/Llama-2-70b-hf
model_type: LLAMA2
resume_from_checkpoint: False
save_adapter_weights_only: False

# Dataset and Sampler
dataset:
1 change: 1 addition & 0 deletions recipes/configs/llama2/7B_lora.yaml
@@ -43,6 +43,7 @@ checkpointer:
output_dir: /tmp/Llama-2-7b-hf
model_type: LLAMA2
resume_from_checkpoint: False
save_adapter_weights_only: False

# Dataset and Sampler
dataset:
1 change: 1 addition & 0 deletions recipes/configs/llama2/7B_lora_dpo.yaml
@@ -43,6 +43,7 @@ checkpointer:
output_dir: /tmp/Llama-2-7b-hf
model_type: LLAMA2
resume_from_checkpoint: False
save_adapter_weights_only: False

# Dataset and Sampler
dataset:
1 change: 1 addition & 0 deletions recipes/configs/llama2/7B_lora_dpo_single_device.yaml
@@ -42,6 +42,7 @@ checkpointer:
output_dir: /tmp/Llama-2-7b-hf
model_type: LLAMA2
resume_from_checkpoint: False
save_adapter_weights_only: False

# Dataset and Sampler
dataset:
1 change: 1 addition & 0 deletions recipes/configs/llama2/7B_lora_single_device.yaml
@@ -41,6 +41,7 @@ checkpointer:
output_dir: /tmp/Llama-2-7b-hf
model_type: LLAMA2
resume_from_checkpoint: False
save_adapter_weights_only: False

# Dataset and Sampler
dataset:
1 change: 1 addition & 0 deletions recipes/configs/llama2/7B_qlora_single_device.yaml
@@ -40,6 +40,7 @@ checkpointer:
output_dir: /tmp/Llama-2-7b-hf
model_type: LLAMA2
resume_from_checkpoint: False
save_adapter_weights_only: False

# Dataset and Sampler
dataset:
1 change: 1 addition & 0 deletions recipes/configs/llama3/70B_lora.yaml
@@ -61,6 +61,7 @@ checkpointer:
output_dir: /tmp/Meta-Llama-3-70B-Instruct
model_type: LLAMA3
resume_from_checkpoint: False
save_adapter_weights_only: False

# Dataset and Sampler
dataset:
1 change: 1 addition & 0 deletions recipes/configs/llama3/8B_lora.yaml
@@ -41,6 +41,7 @@ checkpointer:
output_dir: /tmp/Meta-Llama-3-8B-Instruct/
model_type: LLAMA3
resume_from_checkpoint: False
save_adapter_weights_only: False

# Dataset and Sampler
dataset:
1 change: 1 addition & 0 deletions recipes/configs/llama3/8B_lora_single_device.yaml
@@ -40,6 +40,7 @@ checkpointer:
output_dir: /tmp/Meta-Llama-3-8B-Instruct/
model_type: LLAMA3
resume_from_checkpoint: False
save_adapter_weights_only: False

# Dataset and Sampler
dataset:
1 change: 1 addition & 0 deletions recipes/configs/llama3/8B_qlora_single_device.yaml
@@ -39,6 +39,7 @@ checkpointer:
output_dir: /tmp/Meta-Llama-3-8B-Instruct/
model_type: LLAMA3
resume_from_checkpoint: False
save_adapter_weights_only: False

# Dataset and Sampler
dataset:
1 change: 1 addition & 0 deletions recipes/configs/llama3_1/70B_lora.yaml
@@ -60,6 +60,7 @@ checkpointer:
output_dir: /tmp/Meta-Llama-3.1-70B-Instruct/
model_type: LLAMA3
resume_from_checkpoint: False
save_adapter_weights_only: False

# Dataset and Sampler
dataset:
1 change: 1 addition & 0 deletions recipes/configs/llama3_1/8B_lora.yaml
@@ -44,6 +44,7 @@ checkpointer:
output_dir: /tmp/Meta-Llama-3.1-8B-Instruct/
model_type: LLAMA3
resume_from_checkpoint: False
save_adapter_weights_only: False

# Dataset and Sampler
dataset:
1 change: 1 addition & 0 deletions recipes/configs/llama3_1/8B_lora_single_device.yaml
@@ -43,6 +43,7 @@ checkpointer:
output_dir: /tmp/Meta-Llama-3.1-8B-Instruct/
model_type: LLAMA3
resume_from_checkpoint: False
save_adapter_weights_only: False

# Dataset and Sampler
dataset:
1 change: 1 addition & 0 deletions recipes/configs/llama3_1/8B_qlora_single_device.yaml
@@ -42,6 +42,7 @@ checkpointer:
output_dir: /tmp/Meta-Llama-3.1-8B-Instruct/
model_type: LLAMA3
resume_from_checkpoint: False
save_adapter_weights_only: False

# Dataset and Sampler
dataset:
1 change: 1 addition & 0 deletions recipes/configs/mistral/7B_lora.yaml
@@ -53,6 +53,7 @@ checkpointer:
output_dir: /tmp/Mistral-7B-v0.1
model_type: MISTRAL
resume_from_checkpoint: False
save_adapter_weights_only: False

optimizer:
_component_: torch.optim.AdamW
1 change: 1 addition & 0 deletions recipes/configs/mistral/7B_lora_single_device.yaml
@@ -50,6 +50,7 @@ checkpointer:
output_dir: /tmp/Mistral-7B-v0.1
model_type: MISTRAL
resume_from_checkpoint: False
save_adapter_weights_only: False

optimizer:
_component_: torch.optim.AdamW
1 change: 1 addition & 0 deletions recipes/configs/mistral/7B_qlora_single_device.yaml
@@ -51,6 +51,7 @@ checkpointer:
output_dir: /tmp/Mistral-7B-v0.1
model_type: MISTRAL
resume_from_checkpoint: False
save_adapter_weights_only: False

optimizer:
_component_: torch.optim.AdamW
1 change: 1 addition & 0 deletions recipes/configs/phi3/mini_lora.yaml
@@ -43,6 +43,7 @@ checkpointer:
output_dir: /tmp/Phi-3-mini-4k-instruct
model_type: PHI3_MINI
resume_from_checkpoint: False
save_adapter_weights_only: False

# Dataset
dataset:
1 change: 1 addition & 0 deletions recipes/configs/phi3/mini_lora_single_device.yaml
@@ -41,6 +41,7 @@ checkpointer:
output_dir: /tmp/Phi-3-mini-4k-instruct
model_type: PHI3_MINI
resume_from_checkpoint: False
save_adapter_weights_only: False

# Dataset
dataset:
1 change: 1 addition & 0 deletions recipes/configs/phi3/mini_qlora_single_device.yaml
@@ -41,6 +41,7 @@ checkpointer:
output_dir: /tmp/Phi-3-mini-4k-instruct
model_type: PHI3_MINI
resume_from_checkpoint: False
save_adapter_weights_only: False

# Dataset
dataset:
8 changes: 4 additions & 4 deletions recipes/lora_dpo_distributed.py
@@ -141,8 +141,8 @@ def __init__(self, cfg: DictConfig) -> None:
self.total_epochs = cfg.epochs
self.max_steps_per_epoch = cfg.max_steps_per_epoch
self.global_step = 0

self._resume_from_checkpoint = cfg.resume_from_checkpoint
self._save_adapter_weights_only = cfg.get("save_adapter_weights_only", False)
self._gradient_accumulation_steps = cfg.gradient_accumulation_steps

def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
@@ -470,10 +470,9 @@ def save_checkpoint(
- Merged weights with key MODEL_KEY
- Adapter weights with key ADAPTER_KEY
- Relevant recipe state if training is not complete
- If the `self._save_adapter_weights_only` option is True, the checkpointer will save only the adapter weights

Checkpointer will save the merged weights, adapter weights and recipe state in
different checkpoint files. To correctly resume from training, the adapter weights
and recipe state must be provided along with the base model weights.
To correctly resume training, the adapter weights and recipe state must be provided along with the base model weights.
"""
# final dict passed onto the checkpointer
checkpoint_dict = {}
@@ -530,6 +529,7 @@ def save_checkpoint(
checkpoint_dict,
epoch=epoch,
intermediate_checkpoint=intermediate_checkpoint,
adapter_only=self._save_adapter_weights_only,
)

def concatenated_forward(
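For context on how the new flag is consumed, here is a minimal sketch of what a checkpointer can do with `adapter_only`; the function name, key constants, and filenames below are simplified assumptions, not torchtune's actual checkpointer implementation.

import os
from typing import Any, Dict

import torch

# Stand-ins for torchtune's utils.MODEL_KEY / utils.ADAPTER_KEY constants.
MODEL_KEY, ADAPTER_KEY = "model", "adapter"

def save_checkpoint_sketch(
    state_dict: Dict[str, Any],
    output_dir: str,
    epoch: int,
    adapter_only: bool = False,
) -> None:
    if adapter_only and ADAPTER_KEY not in state_dict:
        raise ValueError("adapter_only=True requires adapter weights in the state dict")
    if not adapter_only:
        # Merged full-model weights are written only when adapter_only is False.
        torch.save(state_dict[MODEL_KEY], os.path.join(output_dir, f"model_{epoch}.pt"))
    # Adapter weights are always written so that LoRA training can be resumed.
    torch.save(state_dict[ADAPTER_KEY], os.path.join(output_dir, f"adapter_{epoch}.pt"))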
15 changes: 9 additions & 6 deletions recipes/lora_dpo_single_device.py
@@ -103,8 +103,8 @@ def __init__(self, cfg: DictConfig) -> None:
self.total_epochs = cfg.epochs
self.max_steps_per_epoch = cfg.max_steps_per_epoch
self.global_step = 0

self._resume_from_checkpoint = cfg.resume_from_checkpoint
self._save_adapter_weights_only = cfg.get("save_adapter_weights_only", False)
self._gradient_accumulation_steps = cfg.gradient_accumulation_steps

def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
@@ -361,14 +361,15 @@ def save_checkpoint(self, epoch: int) -> None:
- Merged weights with key MODEL_KEY
- Adapter weights with key ADAPTER_KEY
- Relevant recipe state if training is not complete
- If the `self._save_adapter_weights_only` option is True, the checkpointer will save only the adapter weights

Checkpointer will save the merged weights, adapter weights and recipe state in
different checkpoint files. To correctly resume from training, the adapter weights
and recipe state must be provided along with the base model weights.
To correctly resume training, the adapter weights and recipe state must be provided along with the base model weights.
"""
ckpt_dict = {}

intermediate_checkpoint = epoch + 1 < self.total_epochs
# if training is in-progress, checkpoint the optimizer state as well
if epoch + 1 < self.total_epochs:
if intermediate_checkpoint:
ckpt_dict.update(
{
utils.OPT_KEY: self._optimizer.state_dict(),
@@ -396,10 +397,12 @@ def save_checkpoint(self, epoch: int) -> None:
k: v for k, v in self._model.state_dict().items() if adapter_key_filter(k)
}
ckpt_dict.update({utils.ADAPTER_KEY: adapter_state_dict})

self._checkpointer.save_checkpoint(
ckpt_dict,
epoch=epoch,
intermediate_checkpoint=(epoch + 1 < self.total_epochs),
intermediate_checkpoint=intermediate_checkpoint,
adapter_only=self._save_adapter_weights_only,
)

def concatenated_forward(
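The adapter state dict saved above is built with an `adapter_key_filter` defined elsewhere in the recipe. A minimal sketch of the idea, assuming adapter parameters are identifiable by the `lora_a`/`lora_b` names used by torchtune's LoRALinear modules:

from typing import Any, Dict

import torch.nn as nn

def extract_adapter_state_dict(model: nn.Module) -> Dict[str, Any]:
    """Keep only LoRA adapter parameters from a model's state dict."""

    def adapter_key_filter(key: str) -> bool:
        # Assumption for this sketch: adapter weights carry "lora_a" or "lora_b"
        # in their parameter names.
        return "lora_a" in key or "lora_b" in key

    return {k: v for k, v in model.state_dict().items() if adapter_key_filter(k)}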