Apply rope on k earlier for efficiency #1558

Merged · 2 commits merged into pytorch:main on Sep 16, 2024

Conversation

@jackzhxng (Contributor) commented Sep 12, 2024

Context

By applying rotary positional embeddings (RoPE) to k earlier, we can squeeze out a bit more performance. Comparing finetuning statistics for Llama3.1-Instruct 8B with LoRA on GPU (full stats are in the "Appendix" section at the end of this PR description):

  • Tokens per second increased from 1246.52 to 1316.83 (roughly +70 tokens per second)
  • Loss stayed about the same (1.48158 -> 1.48607)
  • Peak active memory stayed the same (18.14733 -> 18.14733)

Changelog

Applies RoPE to k before it is expanded to match the number of query heads, i.e. while it is still shaped [b, s_y, n_kv, h_d] rather than [b, s_y, n_h, h_d]; see the sketch below.
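For illustration, here is a minimal sketch of this ordering. It is not the actual torchtune attention code; `expand_kv_after_rope` and the `rope` callable are hypothetical stand-ins for the corresponding pieces of the attention module.

```python
import torch

def expand_kv_after_rope(k: torch.Tensor, rope, n_h: int) -> torch.Tensor:
    """Apply RoPE while k still has n_kv heads, then expand for GQA.

    k    -- key tensor shaped [b, s_y, n_kv, h_d]
    rope -- any callable applying rotary embeddings to its input
    n_h  -- number of query heads (a multiple of n_kv)
    """
    b, s_y, n_kv, h_d = k.shape
    q_per_kv = n_h // n_kv

    # Rotate the smaller [b, s_y, n_kv, h_d] tensor (this PR's ordering) ...
    k = rope(k)

    # ... and only then replicate kv heads to match the n_h query heads.
    k = k.unsqueeze(3).expand(b, s_y, n_kv, q_per_kv, h_d)
    return k.reshape(b, s_y, n_h, h_d)
```

Because RoPE now runs over n_kv rather than n_h heads, it touches n_h / n_kv times fewer elements (e.g. 4x fewer when n_h = 32 and n_kv = 8), which is where the extra throughput comes from; the result is numerically identical since the expansion only repeats heads.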

Test plan

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

UX

No public API changes

Appendix

Stats from one epoch of finetuning Llama3.1-Instruct 8B prior to this PR:

1|25|Loss: 1.481581211090088: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 25/25 [12:49<00:00, 30.78s/it]
wandb:                                                                                
wandb: 
wandb: Run history:
wandb:               global_step ▁▁▂▂▂▂▃▃▃▄▄▄▅▅▅▅▆▆▆▇▇▇▇██
wandb:                      loss ▆▆▇▇▆▆▇▇▅▇▅▅█▅▆▆▅▅▄▃▃▃▂▂▁
wandb:                        lr ▁▁▂▂▂▂▃▃▃▄▄▄▄▅▅▅▆▆▆▇▇▇▇██
wandb:        peak_memory_active ▅▆▇█▆█▇▆▄██▆▅▆▁▆▆▇█▄▇▇▃▅▇
wandb:         peak_memory_alloc ▅▆▇█▆█▇▆▄██▆▅▆▁▆▆▇█▄▇▇▃▅▇
wandb:      peak_memory_reserved ▁▁▃██████████████████████
wandb: tokens_per_second_per_gpu ▄▅▅▅▅▃▂▁▆▃▅█▃▆▁▄▂▅▂▃▄▅▃▂▅
wandb: 
wandb: Run summary:
wandb:               global_step 25
wandb:                      loss 1.48158
wandb:                        lr 7e-05
wandb:        peak_memory_active 18.14733
wandb:         peak_memory_alloc 18.14733
wandb:      peak_memory_reserved 19.43555
wandb: tokens_per_second_per_gpu 1246.51852
wandb: 
wandb: 🚀 View run lilac-puddle-14 at: https://wandb.ai/dvorjackz-meta/torchtune/runs/21e7roul
wandb: ⭐️ View project at: https://wandb.ai/dvorjackz-meta/torchtune
wandb: Synced 5 W&B file(s), 0 media file(s), 2 artifact file(s) and 1 other file(s)
wandb: Find logs at: /tmp/wandb/run-20240916_071220-21e7roul/logs

Stats from one epoch of finetuning Llama3.1-Instruct 8B after this PR:

1|25|Loss: 1.48606538772583: 100%|██████████████████████████████| 25/25 [12:34<00:00, 30.17s/it]
wandb:                                                                                
wandb: 
wandb: Run history:
wandb:               global_step ▁▁▂▂▂▂▃▃▃▄▄▄▅▅▅▅▆▆▆▇▇▇▇██
wandb:                      loss ▆▆▇▆▆▆▇▇▅▇▅▅█▆▆▆▅▅▄▃▃▂▂▂▁
wandb:                        lr ▁▁▂▂▂▂▃▃▃▄▄▄▄▅▅▅▆▆▆▇▇▇▇██
wandb:        peak_memory_active ▅▆▇█▆█▇▆▄██▆▅▆▁▆▆▇█▄▇▇▃▅▇
wandb:         peak_memory_alloc ▅▆▇█▆█▇▆▄██▆▅▆▁▆▆▇█▄▇▇▃▅▇
wandb:      peak_memory_reserved ▁▁▃██████████████████████
wandb: tokens_per_second_per_gpu ▅▇▄▆▆▅▅▂▆▄▇█▂▆▂▄▁▄▃▃▄▅▄▄▆
wandb: 
wandb: Run summary:
wandb:               global_step 25
wandb:                      loss 1.48607
wandb:                        lr 7e-05
wandb:        peak_memory_active 18.14733
wandb:         peak_memory_alloc 18.14733
wandb:      peak_memory_reserved 19.43555
wandb: tokens_per_second_per_gpu 1316.82931
wandb: 
wandb: 🚀 View run azure-forest-13 at: https://wandb.ai/dvorjackz-meta/torchtune/runs/uw4eja2j
wandb: ⭐️ View project at: https://wandb.ai/dvorjackz-meta/torchtune
wandb: Synced 4 W&B file(s), 0 media file(s), 3 artifact file(s) and 1 other file(s)
wandb: Find logs at: /tmp/wandb/run-20240916_064527-uw4eja2j/logs

pytorch-bot bot commented Sep 12, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1558

✅ No Failures

As of commit dd3bf44 with merge base d7fae96:
💚 Looks good so far! There are no failures yet. 💚


@facebook-github-bot added the CLA Signed label Sep 12, 2024
@codecov-commenter commented

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 73.32%. Comparing base (221031a) to head (75f6975).
Report is 1 commit behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1558      +/-   ##
==========================================
- Coverage   73.36%   73.32%   -0.04%     
==========================================
  Files         287      287              
  Lines       14142    14123      -19     
==========================================
- Hits        10375    10356      -19     
  Misses       3767     3767              


@felipemello1 (Contributor) commented Sep 13, 2024

Thanks for the PR! Nice catch! Do you think you can run it with vs. without your changes and show the improvement plus no change in loss? That will make it easier to approve the PR.

If you use Weights & Biases, it's probably the easiest way to take a screenshot (you can log in with your Gmail, get the token, pip install wandb, wandb login, insert your token).

Then you can run your config like this:
tune run lora_finetune_single_device --config llama3_1/8B_lora_single_device max_steps_per_epoch=25 epochs=1 metric_logger=torchtune.training.metric_logging.WandBLogger log_peak_memory_stats=True

You can then paste the loss, tokens per second, and active memory.
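As an aside, a minimal sketch of pulling those summary numbers back out of W&B with its public API; the run paths below are the ones linked in the Appendix, and this assumes they are readable with your API key:

```python
import wandb

api = wandb.Api()
# Run IDs taken from the W&B links in the Appendix above.
before = api.run("dvorjackz-meta/torchtune/21e7roul")  # prior to this PR
after = api.run("dvorjackz-meta/torchtune/uw4eja2j")   # with this PR

for key in ("loss", "tokens_per_second_per_gpu", "peak_memory_active"):
    print(f"{key}: {before.summary.get(key)} -> {after.summary.get(key)}")
```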

@kimishpatel left a comment


Looks good to me and I will stamp, but hopefully other folks in tune are ok with this.

@ebsmothers (Contributor) commented

Yeah, +1 to @felipemello1's suggestion. Otherwise this looks good to me; I can stamp once you update the summary with the results.

@kimishpatel commented

@ebsmothers I tried to find your GH handle for tagging. Now I know.

@jackzhxng (Contributor, Author) commented

@felipemello1 @ebsmothers updated the summary with before and after statistics from finetuning.

@felipemello1 (Contributor) commented

Cool, thanks! Let's merge it after the tests are done.

@jackzhxng (Contributor, Author) commented

@felipemello1 need another approval since there have been changes since Kimish's approval 👀

@felipemello1 merged commit bc2c013 into pytorch:main on Sep 16, 2024
17 checks passed
@jackzhxng deleted the jackxz/rewrite-attention branch on September 16, 2024