diff --git a/docs/source/deep_dives/checkpointer.rst b/docs/source/deep_dives/checkpointer.rst index 5fc9873f26..83d8d08cb4 100644 --- a/docs/source/deep_dives/checkpointer.rst +++ b/docs/source/deep_dives/checkpointer.rst @@ -361,7 +361,7 @@ Checkpointing for LoRA In torchtune, we output both the adapter weights and the full model "merged" weights for LoRA. The "merged" checkpoint can be used just like you would use the source checkpoint with any post-training tools. For more details, take a look at our -:ref:`LoRA Finetuning Tutorial `. +:ref:`LoRA Finetuning Tutorial `. Additionally, by setting the option "save_adapter_weights_only" to True when saving a checkpoint, you can choose to save only the adapter weights. The primary difference between the two use cases is when you want to resume training from a checkpoint. In this case, the checkpointer needs access to both the initial frozen @@ -404,6 +404,9 @@ looks something like this: # set to True if restarting training resume_from_checkpoint: True + # Set to True to save only the adapter weights + save_adapter_weights_only: False + | Putting this all together diff --git a/recipes/configs/code_llama2/7B_lora_single_device.yaml b/recipes/configs/code_llama2/7B_lora_single_device.yaml index b85117fc57..274d926f77 100644 --- a/recipes/configs/code_llama2/7B_lora_single_device.yaml +++ b/recipes/configs/code_llama2/7B_lora_single_device.yaml @@ -49,6 +49,7 @@ checkpointer: output_dir: /tmp/CodeLlama-7b-hf model_type: LLAMA2 resume_from_checkpoint: False +save_adapter_weights_only: False # Fine-tuning arguments batch_size: 2 diff --git a/recipes/configs/code_llama2/7B_qlora_single_device.yaml b/recipes/configs/code_llama2/7B_qlora_single_device.yaml index 1887ebaf39..c2a149147a 100644 --- a/recipes/configs/code_llama2/7B_qlora_single_device.yaml +++ b/recipes/configs/code_llama2/7B_qlora_single_device.yaml @@ -49,6 +49,7 @@ checkpointer: output_dir: /tmp/CodeLlama-7b-hf model_type: LLAMA2 resume_from_checkpoint: False +save_adapter_weights_only: False # Fine-tuning arguments and training batch_size: 2 diff --git a/recipes/configs/dev/llama2/13B_lora_fsdp2.yaml b/recipes/configs/dev/llama2/13B_lora_fsdp2.yaml index f7b4fc8b77..f6cbf97e1d 100644 --- a/recipes/configs/dev/llama2/13B_lora_fsdp2.yaml +++ b/recipes/configs/dev/llama2/13B_lora_fsdp2.yaml @@ -45,6 +45,7 @@ checkpointer: output_dir: /tmp/Llama-2-13b-hf/ model_type: LLAMA2 resume_from_checkpoint: False +save_adapter_weights_only: False # Tokenizer tokenizer: diff --git a/recipes/configs/dev/llama2/70B_lora_fsdp2.yaml b/recipes/configs/dev/llama2/70B_lora_fsdp2.yaml index c2aac60310..c1dd06572c 100644 --- a/recipes/configs/dev/llama2/70B_lora_fsdp2.yaml +++ b/recipes/configs/dev/llama2/70B_lora_fsdp2.yaml @@ -50,6 +50,7 @@ checkpointer: output_dir: /tmp/Llama-2-70b-hf model_type: LLAMA2 resume_from_checkpoint: False +save_adapter_weights_only: False # Dataset and Sampler dataset: diff --git a/recipes/configs/dev/llama2/70B_qlora_fsdp2.yaml b/recipes/configs/dev/llama2/70B_qlora_fsdp2.yaml index d8bf55b6bc..8fee1e7346 100644 --- a/recipes/configs/dev/llama2/70B_qlora_fsdp2.yaml +++ b/recipes/configs/dev/llama2/70B_qlora_fsdp2.yaml @@ -50,6 +50,7 @@ checkpointer: output_dir: /tmp/Llama-2-70b-hf model_type: LLAMA2 resume_from_checkpoint: False +save_adapter_weights_only: False # Dataset and Sampler dataset: diff --git a/recipes/configs/dev/llama2/7B_lora_fsdp2.yaml b/recipes/configs/dev/llama2/7B_lora_fsdp2.yaml index aba978486a..cb282365a7 100644 ---
a/recipes/configs/dev/llama2/7B_lora_fsdp2.yaml +++ b/recipes/configs/dev/llama2/7B_lora_fsdp2.yaml @@ -47,6 +47,7 @@ checkpointer: output_dir: /tmp/Llama-2-7b-hf model_type: LLAMA2 resume_from_checkpoint: False +save_adapter_weights_only: False # Dataset and Sampler dataset: diff --git a/recipes/configs/dev/llama2/7B_qlora_fsdp2.yaml b/recipes/configs/dev/llama2/7B_qlora_fsdp2.yaml index 00dd1e4fe8..4e23f07d30 100644 --- a/recipes/configs/dev/llama2/7B_qlora_fsdp2.yaml +++ b/recipes/configs/dev/llama2/7B_qlora_fsdp2.yaml @@ -46,6 +46,7 @@ checkpointer: output_dir: /tmp/Llama-2-7b-hf model_type: LLAMA2 resume_from_checkpoint: False +save_adapter_weights_only: False # Dataset and Sampler dataset: diff --git a/recipes/configs/gemma/2B_lora.yaml b/recipes/configs/gemma/2B_lora.yaml index 3fdeb3f4cf..ae5141d8b2 100644 --- a/recipes/configs/gemma/2B_lora.yaml +++ b/recipes/configs/gemma/2B_lora.yaml @@ -46,6 +46,7 @@ checkpointer: output_dir: /tmp/gemma-2b model_type: GEMMA resume_from_checkpoint: False +save_adapter_weights_only: False optimizer: _component_: torch.optim.AdamW diff --git a/recipes/configs/gemma/2B_lora_single_device.yaml b/recipes/configs/gemma/2B_lora_single_device.yaml index ce8bc9504c..1e785dcf45 100644 --- a/recipes/configs/gemma/2B_lora_single_device.yaml +++ b/recipes/configs/gemma/2B_lora_single_device.yaml @@ -45,6 +45,7 @@ checkpointer: output_dir: /tmp/gemma-2b model_type: GEMMA resume_from_checkpoint: False +save_adapter_weights_only: False optimizer: _component_: torch.optim.AdamW diff --git a/recipes/configs/gemma/2B_qlora_single_device.yaml b/recipes/configs/gemma/2B_qlora_single_device.yaml index c50d5935ad..ecfbf3da06 100644 --- a/recipes/configs/gemma/2B_qlora_single_device.yaml +++ b/recipes/configs/gemma/2B_qlora_single_device.yaml @@ -45,6 +45,7 @@ checkpointer: output_dir: /tmp/gemma-2b model_type: GEMMA resume_from_checkpoint: False +save_adapter_weights_only: False optimizer: _component_: torch.optim.AdamW diff --git a/recipes/configs/gemma/7B_lora.yaml b/recipes/configs/gemma/7B_lora.yaml index 432696d639..fb78a3eebb 100644 --- a/recipes/configs/gemma/7B_lora.yaml +++ b/recipes/configs/gemma/7B_lora.yaml @@ -48,6 +48,7 @@ checkpointer: output_dir: /tmp/gemma-7b/ model_type: GEMMA resume_from_checkpoint: False +save_adapter_weights_only: False optimizer: _component_: torch.optim.AdamW diff --git a/recipes/configs/gemma/7B_lora_single_device.yaml b/recipes/configs/gemma/7B_lora_single_device.yaml index dfab61e13b..7f9fd7ea39 100644 --- a/recipes/configs/gemma/7B_lora_single_device.yaml +++ b/recipes/configs/gemma/7B_lora_single_device.yaml @@ -47,6 +47,7 @@ checkpointer: output_dir: /tmp/gemma-7b/ model_type: GEMMA resume_from_checkpoint: False +save_adapter_weights_only: False optimizer: _component_: torch.optim.AdamW diff --git a/recipes/configs/gemma/7B_qlora_single_device.yaml b/recipes/configs/gemma/7B_qlora_single_device.yaml index 9d737321cf..08f05bc6f5 100644 --- a/recipes/configs/gemma/7B_qlora_single_device.yaml +++ b/recipes/configs/gemma/7B_qlora_single_device.yaml @@ -47,6 +47,7 @@ checkpointer: output_dir: /tmp/gemma-7b/ model_type: GEMMA resume_from_checkpoint: False +save_adapter_weights_only: False optimizer: _component_: torch.optim.AdamW diff --git a/recipes/configs/llama2/13B_lora.yaml b/recipes/configs/llama2/13B_lora.yaml index e52549562a..f31229feb9 100644 --- a/recipes/configs/llama2/13B_lora.yaml +++ b/recipes/configs/llama2/13B_lora.yaml @@ -41,6 +41,7 @@ checkpointer: output_dir: /tmp/Llama-2-13b-hf/ model_type: LLAMA2 
resume_from_checkpoint: False +save_adapter_weights_only: False # Tokenizer tokenizer: diff --git a/recipes/configs/llama2/13B_qlora_single_device.yaml b/recipes/configs/llama2/13B_qlora_single_device.yaml index b0fc9e6a35..ac0bcab468 100644 --- a/recipes/configs/llama2/13B_qlora_single_device.yaml +++ b/recipes/configs/llama2/13B_qlora_single_device.yaml @@ -41,6 +41,7 @@ checkpointer: output_dir: /tmp/Llama-2-13b-hf/ model_type: LLAMA2 resume_from_checkpoint: False +save_adapter_weights_only: False # Dataset and Sampler dataset: diff --git a/recipes/configs/llama2/70B_lora.yaml b/recipes/configs/llama2/70B_lora.yaml index 46f17fb3c1..ced6edbeb7 100644 --- a/recipes/configs/llama2/70B_lora.yaml +++ b/recipes/configs/llama2/70B_lora.yaml @@ -46,6 +46,7 @@ checkpointer: output_dir: /tmp/Llama-2-70b-hf model_type: LLAMA2 resume_from_checkpoint: False +save_adapter_weights_only: False # Dataset and Sampler dataset: diff --git a/recipes/configs/llama2/7B_lora.yaml b/recipes/configs/llama2/7B_lora.yaml index 9a65921503..0cfcef26d2 100644 --- a/recipes/configs/llama2/7B_lora.yaml +++ b/recipes/configs/llama2/7B_lora.yaml @@ -43,6 +43,7 @@ checkpointer: output_dir: /tmp/Llama-2-7b-hf model_type: LLAMA2 resume_from_checkpoint: False +save_adapter_weights_only: False # Dataset and Sampler dataset: diff --git a/recipes/configs/llama2/7B_lora_dpo.yaml b/recipes/configs/llama2/7B_lora_dpo.yaml index b166db8bd7..59435c70d3 100644 --- a/recipes/configs/llama2/7B_lora_dpo.yaml +++ b/recipes/configs/llama2/7B_lora_dpo.yaml @@ -43,6 +43,7 @@ checkpointer: output_dir: /tmp/Llama-2-7b-hf model_type: LLAMA2 resume_from_checkpoint: False +save_adapter_weights_only: False # Dataset and Sampler dataset: diff --git a/recipes/configs/llama2/7B_lora_dpo_single_device.yaml b/recipes/configs/llama2/7B_lora_dpo_single_device.yaml index 24f815dfb9..088303ca57 100644 --- a/recipes/configs/llama2/7B_lora_dpo_single_device.yaml +++ b/recipes/configs/llama2/7B_lora_dpo_single_device.yaml @@ -42,6 +42,7 @@ checkpointer: output_dir: /tmp/Llama-2-7b-hf model_type: LLAMA2 resume_from_checkpoint: False +save_adapter_weights_only: False # Dataset and Sampler dataset: diff --git a/recipes/configs/llama2/7B_lora_single_device.yaml b/recipes/configs/llama2/7B_lora_single_device.yaml index ca5603d13e..b68f851ff4 100644 --- a/recipes/configs/llama2/7B_lora_single_device.yaml +++ b/recipes/configs/llama2/7B_lora_single_device.yaml @@ -41,6 +41,7 @@ checkpointer: output_dir: /tmp/Llama-2-7b-hf model_type: LLAMA2 resume_from_checkpoint: False +save_adapter_weights_only: False # Dataset and Sampler dataset: diff --git a/recipes/configs/llama2/7B_qlora_single_device.yaml b/recipes/configs/llama2/7B_qlora_single_device.yaml index d32dcfac8e..496d02a8a6 100644 --- a/recipes/configs/llama2/7B_qlora_single_device.yaml +++ b/recipes/configs/llama2/7B_qlora_single_device.yaml @@ -40,6 +40,7 @@ checkpointer: output_dir: /tmp/Llama-2-7b-hf model_type: LLAMA2 resume_from_checkpoint: False +save_adapter_weights_only: False # Dataset and Sampler dataset: diff --git a/recipes/configs/llama3/70B_lora.yaml b/recipes/configs/llama3/70B_lora.yaml index f5e5224040..0dcd9e84b3 100644 --- a/recipes/configs/llama3/70B_lora.yaml +++ b/recipes/configs/llama3/70B_lora.yaml @@ -61,6 +61,7 @@ checkpointer: output_dir: /tmp/Meta-Llama-3-70B-Instruct model_type: LLAMA3 resume_from_checkpoint: False +save_adapter_weights_only: False # Dataset and Sampler dataset: diff --git a/recipes/configs/llama3/8B_lora.yaml b/recipes/configs/llama3/8B_lora.yaml index 
27354b7fde..bd4a621a03 100644 --- a/recipes/configs/llama3/8B_lora.yaml +++ b/recipes/configs/llama3/8B_lora.yaml @@ -41,6 +41,7 @@ checkpointer: output_dir: /tmp/Meta-Llama-3-8B-Instruct/ model_type: LLAMA3 resume_from_checkpoint: False +save_adapter_weights_only: False # Dataset and Sampler dataset: diff --git a/recipes/configs/llama3/8B_lora_single_device.yaml b/recipes/configs/llama3/8B_lora_single_device.yaml index e375c76c31..650b2ae79e 100644 --- a/recipes/configs/llama3/8B_lora_single_device.yaml +++ b/recipes/configs/llama3/8B_lora_single_device.yaml @@ -40,6 +40,7 @@ checkpointer: output_dir: /tmp/Meta-Llama-3-8B-Instruct/ model_type: LLAMA3 resume_from_checkpoint: False +save_adapter_weights_only: False # Dataset and Sampler dataset: diff --git a/recipes/configs/llama3/8B_qlora_single_device.yaml b/recipes/configs/llama3/8B_qlora_single_device.yaml index 2e7fefcd08..d80692725b 100644 --- a/recipes/configs/llama3/8B_qlora_single_device.yaml +++ b/recipes/configs/llama3/8B_qlora_single_device.yaml @@ -39,6 +39,7 @@ checkpointer: output_dir: /tmp/Meta-Llama-3-8B-Instruct/ model_type: LLAMA3 resume_from_checkpoint: False +save_adapter_weights_only: False # Dataset and Sampler dataset: diff --git a/recipes/configs/llama3_1/70B_lora.yaml b/recipes/configs/llama3_1/70B_lora.yaml index 612048536a..fce6cf5118 100644 --- a/recipes/configs/llama3_1/70B_lora.yaml +++ b/recipes/configs/llama3_1/70B_lora.yaml @@ -60,6 +60,7 @@ checkpointer: output_dir: /tmp/Meta-Llama-3.1-70B-Instruct/ model_type: LLAMA3 resume_from_checkpoint: False +save_adapter_weights_only: False # Dataset and Sampler dataset: diff --git a/recipes/configs/llama3_1/8B_lora.yaml b/recipes/configs/llama3_1/8B_lora.yaml index d3e3be5af8..e2ff9fbf53 100644 --- a/recipes/configs/llama3_1/8B_lora.yaml +++ b/recipes/configs/llama3_1/8B_lora.yaml @@ -44,6 +44,7 @@ checkpointer: output_dir: /tmp/Meta-Llama-3.1-8B-Instruct/ model_type: LLAMA3 resume_from_checkpoint: False +save_adapter_weights_only: False # Dataset and Sampler dataset: diff --git a/recipes/configs/llama3_1/8B_lora_single_device.yaml b/recipes/configs/llama3_1/8B_lora_single_device.yaml index bbb0339a7d..57a7b703a2 100644 --- a/recipes/configs/llama3_1/8B_lora_single_device.yaml +++ b/recipes/configs/llama3_1/8B_lora_single_device.yaml @@ -43,6 +43,7 @@ checkpointer: output_dir: /tmp/Meta-Llama-3.1-8B-Instruct/ model_type: LLAMA3 resume_from_checkpoint: False +save_adapter_weights_only: False # Dataset and Sampler dataset: diff --git a/recipes/configs/llama3_1/8B_qlora_single_device.yaml b/recipes/configs/llama3_1/8B_qlora_single_device.yaml index 29fcae8351..3475fcee72 100644 --- a/recipes/configs/llama3_1/8B_qlora_single_device.yaml +++ b/recipes/configs/llama3_1/8B_qlora_single_device.yaml @@ -42,6 +42,7 @@ checkpointer: output_dir: /tmp/Meta-Llama-3.1-8B-Instruct/ model_type: LLAMA3 resume_from_checkpoint: False +save_adapter_weights_only: False # Dataset and Sampler dataset: diff --git a/recipes/configs/mistral/7B_lora.yaml b/recipes/configs/mistral/7B_lora.yaml index 43798fed6e..a84b714119 100644 --- a/recipes/configs/mistral/7B_lora.yaml +++ b/recipes/configs/mistral/7B_lora.yaml @@ -53,6 +53,7 @@ checkpointer: output_dir: /tmp/Mistral-7B-v0.1 model_type: MISTRAL resume_from_checkpoint: False +save_adapter_weights_only: False optimizer: _component_: torch.optim.AdamW diff --git a/recipes/configs/mistral/7B_lora_single_device.yaml b/recipes/configs/mistral/7B_lora_single_device.yaml index 81501bb563..23524b7438 100644 --- 
a/recipes/configs/mistral/7B_lora_single_device.yaml +++ b/recipes/configs/mistral/7B_lora_single_device.yaml @@ -50,6 +50,7 @@ checkpointer: output_dir: /tmp/Mistral-7B-v0.1 model_type: MISTRAL resume_from_checkpoint: False +save_adapter_weights_only: False optimizer: _component_: torch.optim.AdamW diff --git a/recipes/configs/mistral/7B_qlora_single_device.yaml b/recipes/configs/mistral/7B_qlora_single_device.yaml index 1593d398d3..f571ce62cb 100644 --- a/recipes/configs/mistral/7B_qlora_single_device.yaml +++ b/recipes/configs/mistral/7B_qlora_single_device.yaml @@ -51,6 +51,7 @@ checkpointer: output_dir: /tmp/Mistral-7B-v0.1 model_type: MISTRAL resume_from_checkpoint: False +save_adapter_weights_only: False optimizer: _component_: torch.optim.AdamW diff --git a/recipes/configs/phi3/mini_lora.yaml b/recipes/configs/phi3/mini_lora.yaml index b44addcf2e..d5f3bbf584 100644 --- a/recipes/configs/phi3/mini_lora.yaml +++ b/recipes/configs/phi3/mini_lora.yaml @@ -43,6 +43,7 @@ checkpointer: output_dir: /tmp/Phi-3-mini-4k-instruct model_type: PHI3_MINI resume_from_checkpoint: False +save_adapter_weights_only: False # Dataset dataset: diff --git a/recipes/configs/phi3/mini_lora_single_device.yaml b/recipes/configs/phi3/mini_lora_single_device.yaml index 107250b166..1a307a6b67 100644 --- a/recipes/configs/phi3/mini_lora_single_device.yaml +++ b/recipes/configs/phi3/mini_lora_single_device.yaml @@ -41,6 +41,7 @@ checkpointer: output_dir: /tmp/Phi-3-mini-4k-instruct model_type: PHI3_MINI resume_from_checkpoint: False +save_adapter_weights_only: False # Dataset dataset: diff --git a/recipes/configs/phi3/mini_qlora_single_device.yaml b/recipes/configs/phi3/mini_qlora_single_device.yaml index 0b5032b513..dbaf1eed4d 100644 --- a/recipes/configs/phi3/mini_qlora_single_device.yaml +++ b/recipes/configs/phi3/mini_qlora_single_device.yaml @@ -41,6 +41,7 @@ checkpointer: output_dir: /tmp/Phi-3-mini-4k-instruct model_type: PHI3_MINI resume_from_checkpoint: False +save_adapter_weights_only: False # Dataset dataset: diff --git a/recipes/lora_dpo_distributed.py b/recipes/lora_dpo_distributed.py index 53fde99161..f695525be9 100644 --- a/recipes/lora_dpo_distributed.py +++ b/recipes/lora_dpo_distributed.py @@ -141,8 +141,8 @@ def __init__(self, cfg: DictConfig) -> None: self.total_epochs = cfg.epochs self.max_steps_per_epoch = cfg.max_steps_per_epoch self.global_step = 0 - self._resume_from_checkpoint = cfg.resume_from_checkpoint + self._save_adapter_weights_only = cfg.get("save_adapter_weights_only", False) self._gradient_accumulation_steps = cfg.gradient_accumulation_steps def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]: @@ -470,10 +470,9 @@ def save_checkpoint( - Merged weights with key MODEL_KEY - Adapter weights with key ADAPTER_KEY - Relevant recipe state if training is not complete + - If the `self._save_adapter_weights_only` option is True, the checkpointer will save only the adapter weights - Checkpointer will save the merged weights, adapter weights and recipe state in - different checkpoint files. To correctly resume from training, the adapter weights - and recipe state must be provided along with the base model weights. + To correctly resume from training, the adapter weights and recipe state must be provided along with the base model weights. 
""" # final dict passed onto the checkpointer checkpoint_dict = {} @@ -530,6 +529,7 @@ def save_checkpoint( checkpoint_dict, epoch=epoch, intermediate_checkpoint=intermediate_checkpoint, + adapter_only=self._save_adapter_weights_only, ) def concatenated_forward( diff --git a/recipes/lora_dpo_single_device.py b/recipes/lora_dpo_single_device.py index d0751640be..c6a8325e0f 100644 --- a/recipes/lora_dpo_single_device.py +++ b/recipes/lora_dpo_single_device.py @@ -103,8 +103,8 @@ def __init__(self, cfg: DictConfig) -> None: self.total_epochs = cfg.epochs self.max_steps_per_epoch = cfg.max_steps_per_epoch self.global_step = 0 - self._resume_from_checkpoint = cfg.resume_from_checkpoint + self._save_adapter_weights_only = cfg.get("save_adapter_weights_only", False) self._gradient_accumulation_steps = cfg.gradient_accumulation_steps def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]: @@ -361,14 +361,15 @@ def save_checkpoint(self, epoch: int) -> None: - Merged weights with key MODEL_KEY - Adapter weights with key ADAPTER_KEY - Relevant recipe state if training is not complete + - If the `self._save_adapter_weights_only` option is True, the checkpointer will save only the adapter weights - Checkpointer will save the merged weights, adapter weights and recipe state in - different checkpoint files. To correctly resume from training, the adapter weights - and recipe state must be provided along with the base model weights. + To correctly resume from training, the adapter weights and recipe state must be provided along with the base model weights. """ ckpt_dict = {} + + intermediate_checkpoint = epoch + 1 < self.total_epochs # if training is in-progress, checkpoint the optimizer state as well - if epoch + 1 < self.total_epochs: + if intermediate_checkpoint: ckpt_dict.update( { utils.OPT_KEY: self._optimizer.state_dict(), @@ -396,10 +397,12 @@ def save_checkpoint(self, epoch: int) -> None: k: v for k, v in self._model.state_dict().items() if adapter_key_filter(k) } ckpt_dict.update({utils.ADAPTER_KEY: adapter_state_dict}) + self._checkpointer.save_checkpoint( ckpt_dict, epoch=epoch, - intermediate_checkpoint=(epoch + 1 < self.total_epochs), + intermediate_checkpoint=intermediate_checkpoint, + adapter_only=self._save_adapter_weights_only, ) def concatenated_forward( diff --git a/recipes/lora_finetune_distributed.py b/recipes/lora_finetune_distributed.py index 12fd7c5eb9..5bb3280c11 100644 --- a/recipes/lora_finetune_distributed.py +++ b/recipes/lora_finetune_distributed.py @@ -135,8 +135,8 @@ def __init__(self, cfg: DictConfig) -> None: self.total_epochs = cfg.epochs self.max_steps_per_epoch = cfg.max_steps_per_epoch self.global_step = 0 - self._resume_from_checkpoint = cfg.resume_from_checkpoint + self._save_adapter_weights_only = cfg.get("save_adapter_weights_only", False) self._gradient_accumulation_steps = cfg.gradient_accumulation_steps def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]: @@ -537,10 +537,9 @@ def save_checkpoint( - Merged weights with key MODEL_KEY - Adapter weights with key ADAPTER_KEY - Relevant recipe state if training is not complete + - If the `self._save_adapter_weights_only` option is True, the checkpointer will save only the adapter weights - Checkpointer will save the merged weights, adapter weights and recipe state in - different checkpoint files. To correctly resume from training, the adapter weights - and recipe state must be provided along with the base model weights. 
+ To correctly resume from training, the adapter weights and recipe state must be provided along with the base model weights. """ # final dict passed onto the checkpointer checkpoint_dict = {} @@ -609,6 +608,7 @@ save_checkpoint( checkpoint_dict, epoch=epoch, intermediate_checkpoint=intermediate_checkpoint, + adapter_only=self._save_adapter_weights_only, ) def train(self) -> None: diff --git a/recipes/lora_finetune_single_device.py b/recipes/lora_finetune_single_device.py index 4ac2d6876e..ce574ff7be 100644 --- a/recipes/lora_finetune_single_device.py +++ b/recipes/lora_finetune_single_device.py @@ -128,8 +128,8 @@ def __init__(self, cfg: DictConfig) -> None: self.total_epochs = cfg.epochs self.max_steps_per_epoch = cfg.max_steps_per_epoch self.global_step = 0 - self._resume_from_checkpoint = cfg.resume_from_checkpoint + self._save_adapter_weights_only = cfg.get("save_adapter_weights_only", False) self._gradient_accumulation_steps = cfg.gradient_accumulation_steps def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]: @@ -468,14 +468,15 @@ def save_checkpoint(self, epoch: int) -> None: - Merged weights with key MODEL_KEY - Adapter weights with key ADAPTER_KEY - Relevant recipe state if training is not complete + - If the `self._save_adapter_weights_only` option is True, the checkpointer will save only the adapter weights - Checkpointer will save the merged weights, adapter weights and recipe state in - different checkpoint files. To correctly resume from training, the adapter weights - and recipe state must be provided along with the base model weights. + To correctly resume from training, the adapter weights and recipe state must be provided along with the base model weights. """ ckpt_dict = {} + + intermediate_checkpoint = epoch + 1 < self.total_epochs # if training is in-progress, checkpoint the optimizer state as well - if epoch + 1 < self.total_epochs: + if intermediate_checkpoint: ckpt_dict.update( { utils.OPT_KEY: self._optimizer.state_dict(), @@ -514,10 +515,12 @@ def save_checkpoint(self, epoch: int) -> None: "peft_type": "LORA", } ckpt_dict.update({utils.ADAPTER_CONFIG: adapter_config}) + self._checkpointer.save_checkpoint( ckpt_dict, epoch=epoch, - intermediate_checkpoint=(epoch + 1 < self.total_epochs), + intermediate_checkpoint=intermediate_checkpoint, + adapter_only=self._save_adapter_weights_only, ) def train(self) -> None: diff --git a/torchtune/utils/_checkpointing/_checkpointer.py b/torchtune/utils/_checkpointing/_checkpointer.py index 387664e120..7aac6879d7 100644 --- a/torchtune/utils/_checkpointing/_checkpointer.py +++ b/torchtune/utils/_checkpointing/_checkpointer.py @@ -281,6 +281,19 @@ def save_checkpoint( f"{os.path.getsize(output_path) / 1000**3:.2f} GB " f"saved to {output_path}" ) + else: + logger.info("Saving final epoch checkpoint.") + if adapter_only: + logger.info( + "Please note that you have set adapter_only=True, so only adapter weights will be saved. " + "You need to merge the adapter weights into your base model for further use. " + f"See {self.__class__.__name__}.save_checkpoint for more details." + ) + else: + logger.info( + "The full model checkpoint, including all weights and configurations, has been saved successfully. " + "You can now use this checkpoint for further training or inference."
+ ) class FullModelHFCheckpointer(_CheckpointerInterface): @@ -595,6 +608,19 @@ def save_checkpoint( f"{os.path.getsize(output_path) / 1000**3:.2f} GB " f"saved to {output_path}" ) + else: + logger.info("Saving final epoch checkpoint.") + if adapter_only: + logger.info( + "Please note that you have set adapter_only=True, so only adapter weights will be saved. " + "You need to merge the adapter weights into your base model for further use. " + f"See {self.__class__.__name__}.save_checkpoint for more details." + ) + else: + logger.info( + "The full model checkpoint, including all weights and configurations, has been saved successfully. " + "You can now use this checkpoint for further training or inference." + ) class FullModelMetaCheckpointer(_CheckpointerInterface): @@ -742,3 +768,16 @@ save_checkpoint( f"{os.path.getsize(output_path) / 1000**3:.2f} GB " f"saved to {output_path}" ) + else: + logger.info("Saving final epoch checkpoint.") + if adapter_only: + logger.info( + "Please note that you have set adapter_only=True, so only adapter weights will be saved. " + "You need to merge the adapter weights into your base model for further use. " + f"See {self.__class__.__name__}.save_checkpoint for more details." + ) + else: + logger.info( + "The full model checkpoint, including all weights and configurations, has been saved successfully. " + "You can now use this checkpoint for further training or inference." + )
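
Note on usage (not part of the diff): the sketch below illustrates what the new save_adapter_weights_only flag controls once it reaches the checkpointer as adapter_only. The adapter_key_filter helper and the "model"/"adapter" dictionary keys here are hypothetical stand-ins for the recipes' own key filter and the utils.MODEL_KEY / utils.ADAPTER_KEY constants used above; this is a minimal sketch, not torchtune's implementation.

# Illustrative Python sketch only; names below are assumptions, not torchtune APIs.
from typing import Any, Dict

import torch


def adapter_key_filter(key: str) -> bool:
    # Hypothetical filter: keep only LoRA adapter parameters, identified by name.
    return "lora_" in key


def build_checkpoint_dict(
    model: torch.nn.Module, save_adapter_weights_only: bool
) -> Dict[str, Any]:
    full_state_dict = model.state_dict()
    adapter_state_dict = {
        k: v for k, v in full_state_dict.items() if adapter_key_filter(k)
    }
    if save_adapter_weights_only:
        # Only the small adapter weights are written; they must later be merged
        # into (or loaded alongside) the base model before standalone use.
        return {"adapter": adapter_state_dict}
    # Default behaviour: merged full-model weights plus the adapter weights.
    return {"model": full_state_dict, "adapter": adapter_state_dict}

In the recipes changed above, the resulting dictionary is handed to checkpointer.save_checkpoint(..., adapter_only=self._save_adapter_weights_only), and the flag is read with cfg.get("save_adapter_weights_only", False), so configs that omit it keep the current full-save behaviour.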