Use new get_model_state_dict api for save_pretrained peft model #629

Merged: 1 commit into main on Aug 13, 2024

Conversation

@mreso (Contributor) commented on Aug 13, 2024

What does this PR do?

This PR updates the way we save the PEFT checkpoint by using the new DCP API to avoid OOM: https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html

Fixes #626
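
For context, the core of the change is to gather a full, CPU-offloaded state dict with the new torch.distributed.checkpoint state-dict API and pass it to save_pretrained, instead of going through FSDP.state_dict_type(FULL_STATE_DICT). A minimal sketch under that assumption (illustrative, not the exact diff; the helper name save_peft_checkpoint is assumed):

    from torch.distributed.checkpoint.state_dict import (
        StateDictOptions,
        get_model_state_dict,
    )

    def save_peft_checkpoint(model, model_path):
        """Save PEFT/LoRA adapter weights from an FSDP-wrapped model."""
        # full_state_dict=True gathers an unsharded state dict across ranks;
        # cpu_offload=True moves the gathered tensors to host memory instead
        # of materializing them on a single GPU, which avoids the OOM in #626.
        options = StateDictOptions(full_state_dict=True, cpu_offload=True)
        state_dict = get_model_state_dict(model, options=options)
        # PEFT's save_pretrained accepts an explicit state_dict, so only the
        # adapter weights are written out (in practice the write is typically
        # guarded so that only rank 0 touches the filesystem).
        model.save_pretrained(model_path, state_dict=state_dict)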

Feature/Issue validation/testing

Lora fine tune a model and load for inference

  • Test A
    CUDA_VISIBLE_DEVICES=0,1,4,5 torchrun --nnodes 1 --nproc_per_node 4 recipes/quickstart/finetuning/finetuning.py --enable_fsdp --model_name meta-llama/Meta-Llama-3.1-70B-Instruct --use_peft --peft_method lora --output_dir ../llama_output/ --run_validation --save_model --samsum_dataset.trust_remote_code=True --context_length 2048 --max_train_step 1 --max_eval_step 1
    cd recipes/quickstart/inference/local_inference
    cat samsum_prompt.txt | python inference.py --model_name meta-llama/Meta-Llama-3.1-70B-Instruct --peft_model ~/llama_output/
    Logs for Test A training
Training Epoch: 1:   0%| | 0/79 [00:00<?, ?it/s]
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/cuda/memory.py:343: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
  warnings.warn(
Training Epoch: 1:   0%| | 0/79 [00:00<?, ?it/s]
--> applying fsdp activation checkpointing...
--> Training Set Length = 14732
--> Validation Set Length = 818
Preprocessing dataset: 100%|██████████| 14732/14732 [00:04<00:00, 3515.63it/s]
Preprocessing dataset: 100%|██████████| 818/818 [00:00<00:00, 3396.24it/s]
--> Num of Validation Set Batches loaded = 17
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/cuda/memory.py:343: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
  warnings.warn(
Training Epoch: 1:   0%| | 0/79 [00:00<?, ?it/s]
NCCL version 2.20.5+cuda12.4
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/utils/checkpoint.py:295: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/utils/checkpoint.py:1399: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.
  with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]
(each warning repeated once per rank)
Training Epoch: 1/3, step 0/79 completed (loss: 1.2386412620544434):   1%|▎ | 1/79 [00:27<36:09, 27.81s/it]
max training steps reached, stopping training, total train steps finished:  1
Training Epoch: 1/3, step 0/79 completed (loss: 1.2386412620544434):   1%|▎ | 1/79 [00:27<36:16, 27.91s/it]
Training Epoch: 1/3, step 0/79 completed (loss: 1.3376933336257935):   1%|▎ | 1/79 [00:37<49:19, 37.95s/it]
Training Epoch: 1/3, step 0/79 completed (loss: 1.4895133972167969):   1%|▎ | 1/79 [00:37<49:12, 37.85s/it]
Training Epoch: 1/3, step 0/79 completed (loss: 1.4058910608291626):   1%|▎ | 1/79 [00:40<52:39, 40.51s/it]
Max CUDA memory allocated was 62 GB
Max CUDA memory reserved was 67 GB
Peak active CUDA memory was 62 GB
CUDA Malloc retries : 0
CPU Total Peak Memory consumed during the train (max): 7 GB
evaluating Epoch:   6%|█▏ | 1/17 [00:02<00:33,  2.10s/it]
max eval steps reached, stopping evaluation, total_eval_steps:  1
evaluating Epoch:   6%|█▏ | 1/17 [00:02<00:35,  2.25s/it]
evaluating Epoch:   6%|█▏ | 1/17 [00:02<00:35,  2.24s/it]
evaluating Epoch:   6%|█▏ | 1/17 [00:02<00:35,  2.23s/it]
evaluating Epoch:   6%|█▏ | 1/17 [00:02<00:36,  2.31s/it]
 eval_ppl=tensor(1.0827, device='cuda:0') eval_epoch_loss=tensor(0.0795, device='cuda:0')
we are about to save the PEFT modules
/home/mreso/.conda/envs/llama/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py:689: FutureWarning: FSDP.state_dict_type() and FSDP.set_state_dict_type() are being deprecated. Please use APIs, get_state_dict() and set_state_dict(), which can support different parallelisms, FSDP1, FSDP2, DDP. API doc: https://pytorch.org/docs/stable/distributed.checkpoint.html#torch.distributed.checkpoint.state_dict.get_state_dict . Tutorial: https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html .
  warnings.warn(
(warning repeated once per rank)
PEFT modules are saved in ../llama_output/ directory
best eval loss on epoch 1 is 0.07948913425207138
Epoch 1: train_perplexity=1.0175, train_epoch_loss=0.0173, epoch time 29.114889188029338s
Key: avg_train_prep, Value: 1.017466425895691
Key: avg_train_loss, Value: 0.017315631732344627
Key: avg_eval_prep, Value: 1.0827337503433228
Key: avg_eval_loss, Value: 0.07948913425207138
Key: avg_epoch_time, Value: 29.114889188029338
Key: avg_checkpoint_time, Value: 192.18921023100847

Logs for Test A inference

---

Summary:
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
the inference time is 24057.521794049535 ms
User input and model output deemed safe.
Model output:
Summarize this dialog:

A: Hi Tom, are you busy tomorrow’s afternoon?

B: I’m pretty sure I am. What’s up?

A: Can you go with me to the animal shelter?.

B: What do you want to do?

A: I want to get a puppy for my son.

B: That will make him so happy.

A: Yeah, we’ve discussed it many times. I think he’s ready now.

B: That’s good. Raising a dog is a tough issue. Like having a baby ;-)

A: I'll get him one of those little dogs.

B: One that won't grow up too big;-)

A: And eat too much;-))

B: Do you know which one he would like?

A: Oh, yes, I took him there last Monday. He showed me one that he really liked.

B: I bet you had to drag him away.

A: He wanted to take it home right away ;-).

B: I wonder what he'll name it.

A: He said he’d name it after his dead hamster – Lemmy  - he's  a great Motorhead fan :-)))

---

Summary: Tom's friend wants to go to an animal shelter tomorrow to pick out a puppy and asked Tom to accompany him. The friend's son has been wanting a puppy for a while and the friend thinks he's ready for the responsibilities. Tom agrees to go and teases his friend about the upcoming challenges of dog ownership. The friend has already taken his son to the shelter and the son has found a puppy he wants, and plans to name it after a deceased pet.

[This appears to be only a
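
For reference, consuming the saved adapter at inference time conceptually reduces to attaching it to the base model with PEFT. A minimal sketch of what the --peft_model path does (the actual inference.py also handles quantization, device placement, and safety checks):

    from transformers import AutoModelForCausalLM
    from peft import PeftModel

    # Load the frozen base model, then apply the LoRA adapter saved above.
    base = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Meta-Llama-3.1-70B-Instruct"
    )
    model = PeftModel.from_pretrained(base, "../llama_output/")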

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

Thanks for contributing 🎉!

@wukaixingxp (Contributor) left a comment:

Thanks for this fix, LGTM!

@mreso merged commit eca5265 into main on Aug 13, 2024 (3 checks passed)
@mreso deleted the fix/fsdp_model_state_oom branch on August 13, 2024 at 22:16