Apply rope on k earlier for efficiency #1558
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1558
Note: Links to docs will display an error until the docs builds have been completed.
✅ No failures as of commit dd3bf44 with merge base d7fae96.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Codecov Report: All modified and coverable lines are covered by tests ✅

Additional details and impacted files:

```
@@            Coverage Diff             @@
##             main    #1558      +/-   ##
==========================================
- Coverage   73.36%   73.32%   -0.04%
==========================================
  Files         287      287
  Lines       14142    14123      -19
==========================================
- Hits        10375    10356      -19
  Misses       3767     3767
```

☔ View full report in Codecov by Sentry.
Thanks for the PR! Nice catch! Do you think you can run it with vs. without your changes and show the improvement + no change in loss? It will make it easy to approve the PR. If you use Weights & Biases, it's probably the easiest way to take a screenshot (you can log in with your Gmail, get the token, pip install wandb, wandb login, insert your token). Then you can run your config with the W&B logger enabled (a hypothetical example command is sketched below) and paste the loss, tokens per second, and active memory.
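For concreteness, a hypothetical invocation along these lines might look like the following; the recipe name, config name, W&B project name, and the WandBLogger module path are assumptions and may differ across torchtune versions and setups.

```bash
# Install and authenticate Weights & Biases
pip install wandb
wandb login   # paste your token when prompted

# Hypothetical example: run a LoRA finetuning config with the W&B metric logger
# enabled via command-line overrides (recipe/config names and the logger's module
# path are assumptions and may vary by torchtune version).
tune run lora_finetune_single_device \
  --config llama3_1/8B_lora_single_device \
  metric_logger._component_=torchtune.training.metric_logging.WandBLogger \
  metric_logger.project=rope-on-k-perf
```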
Looks good to me and I will stamp, but hopefully other folks in tune are OK with this.
Yeah, +1 to @felipemello1's suggestion. Otherwise this looks good to me; I can stamp once you update the summary with the results.
@ebsmothers I tried to find your GH handle for tagging. Now I know.
@felipemello1 @ebsmothers updated the summary with before and after statistics from finetuning.
Cool! Thanks! Let's merge it after tests are done.
@felipemello1 need another approval since there have been changes since Kimish's approval 👀
Context
By applying the rotary positional embedding (RoPE) earlier, we can squeeze out a bit more performance. Comparing finetuning statistics for Llama3.1-Instruct 8B with LoRA on GPU (full stats are in the "Appendix" section at the end of this PR description):
- Tokens per second: 1246.51852/s -> 1316.82931/s (+70 tokens per second)
- Loss: 1.48158 -> 1.48607
- Peak active memory: 18.14733 -> 18.14733

Changelog
Apply RoPE to k while it is still unexpanded, i.e. with shape `[b, s_y, n_kv, h_d]`, instead of after it has been expanded to `[b, s_y, n_h, h_d]`; a simplified sketch of the reordering follows.
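As a rough illustration (not the actual torchtune attention code; the tensor names, the identity `rope` stand-in, and the expand/reshape helper below are assumptions), applying RoPE before the KV-head expansion means the rotation touches n_kv heads per token instead of n_h:

```python
import torch

# Illustrative shapes: b=batch, s_y=sequence length, n_kv=KV heads, n_h=query heads, h_d=head dim
b, s_y, n_kv, n_h, h_d = 2, 16, 2, 8, 64
q_per_kv = n_h // n_kv
k = torch.randn(b, s_y, n_kv, h_d)

def rope(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for the rotary positional embedding; the real module rotates
    # feature pairs within each head based on token position.
    return x

def expand_kv(x: torch.Tensor) -> torch.Tensor:
    # Duplicate each KV head q_per_kv times: [b, s_y, n_kv, h_d] -> [b, s_y, n_h, h_d]
    return x.unsqueeze(3).expand(b, s_y, n_kv, q_per_kv, h_d).reshape(b, s_y, n_h, h_d)

# Before: expand k to n_h heads first, then apply RoPE over n_h * h_d features per token.
k_before = rope(expand_kv(k))

# After (this PR): apply RoPE while k is still [b, s_y, n_kv, h_d], then expand,
# so the rotation only touches n_kv * h_d features per token.
k_after = expand_kv(rope(k))
```

Because RoPE acts independently on each head, expanding before or after the rotation yields the same tensor, so the reordering only reduces how much work the rotation does.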
Test plan
- pre-commit hooks and linters (installed via `pre-commit install`)
- Unit tests via `pytest tests`
- Recipe tests via `pytest tests -m integration_test`
UX
No public API changes
Appendix
Stats from one epoch of finetuning Llama3.1-Instruct 8B prior to this PR:
Stats from one epoch of finetuning Llama3.1-Instruct 8B after this PR: