Use new get_model_state_dict api for save_pretrained peft model #629

Merged 1 commit on Aug 13, 2024
1 change: 1 addition & 0 deletions src/llama_recipes/model_checkpointing/__init__.py
@@ -4,6 +4,7 @@
 from llama_recipes.model_checkpointing.checkpoint_handler import (
     load_model_checkpoint,
     save_model_checkpoint,
+    save_peft_checkpoint,
     load_optimizer_checkpoint,
     save_optimizer_checkpoint,
     save_model_and_optimizer_sharded,
11 changes: 10 additions & 1 deletion src/llama_recipes/model_checkpointing/checkpoint_handler.py
@@ -26,6 +26,7 @@
 )


+from torch.distributed.checkpoint.state_dict import get_model_state_dict, StateDictOptions
 from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
 import torch.distributed._shard.checkpoint as dist_cp
 import torch.distributed as dist
@@ -264,4 +265,12 @@ def load_sharded_model_single_gpu(model,model_path):
     model.load_state_dict(state_dict["model"])

     print(f"Sharded state checkpoint loaded from {model_path}")
-    return model
+    return model
+
+def save_peft_checkpoint(model, model_path):
+    """save_pretrained peft model"""
+
+    options = StateDictOptions(full_state_dict=True, cpu_offload=True)
+
+    state_dict = get_model_state_dict(model, options=options)
+    model.save_pretrained(model_path, state_dict=state_dict)
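For context (not part of this PR's diff), here is a minimal sketch of the new torch.distributed.checkpoint.state_dict API that save_peft_checkpoint builds on, next to the older FSDP context-manager idiom it supersedes. The helper-function names below are illustrative, not taken from the repository.

# Sketch only: the new state-dict API used by save_peft_checkpoint, alongside the
# older FSDP idiom it replaces. Helper names here are illustrative.
from torch.distributed.checkpoint.state_dict import get_model_state_dict, StateDictOptions
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType, FullStateDictConfig

def gather_full_state_dict_new(model):
    # New API: a single call returns a full (unsharded) state dict, offloaded to CPU,
    # whether or not `model` is wrapped in FSDP.
    options = StateDictOptions(full_state_dict=True, cpu_offload=True)
    return get_model_state_dict(model, options=options)

def gather_full_state_dict_old(model):
    # Older idiom: temporarily switch the FSDP-wrapped model into FULL_STATE_DICT
    # mode before calling state_dict().
    cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
        return model.state_dict()

Passing the gathered dict to save_pretrained(model_path, state_dict=state_dict) lets PEFT write the adapter weights (plus adapter_config.json) from a plain CPU state dict rather than from sharded FSDP parameters.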
4 changes: 2 additions & 2 deletions src/llama_recipes/utils/train_utils.py
@@ -20,7 +20,7 @@
 import json


-from llama_recipes.model_checkpointing import save_model_checkpoint, save_model_and_optimizer_sharded, save_optimizer_checkpoint
+from llama_recipes.model_checkpointing import save_model_checkpoint, save_model_and_optimizer_sharded, save_optimizer_checkpoint, save_peft_checkpoint
 from llama_recipes.policies import fpSixteen,bfSixteen, get_llama_wrapper
 from llama_recipes.utils.memory_utils import MemoryTrace
 from accelerate.utils import is_xpu_available, is_ccl_available
@@ -235,7 +235,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
print(f"we are about to save the PEFT modules")
else:
print(f"we are about to save the PEFT modules")
model.save_pretrained(train_config.output_dir)
save_peft_checkpoint(model, train_config.output_dir)
if train_config.enable_fsdp:
if rank==0:
print(f"PEFT modules are saved in {train_config.output_dir} directory")